Abstract

This study evaluates multiple machine learning models for classifying feedback (1 for positive, -1 for negative) using neural recordings from 18 experimental sessions across four mice, obtained from Steinmetz et al. (2019). Exploratory Data Analysis (EDA) was conducted to examine spike activity patterns, session trends, and behavioral variations. Dimensionality reduction techniques (PCA, t-SNE) and clustering methods were applied to explore neural activation structures. A structured data pipeline was developed to preprocess raw spike data into a feature-rich format for classification.

Four predictive models (Logistic Regression, K-Nearest Neighbors (KNN), Support Vector Machine (SVM), and XGBoost) were evaluated using accuracy, confusion matrices, precision, recall, F1-score, and ROC-AUC curves. XGBoost achieved the highest accuracy (72.47%), outperforming KNN (71.68%), Logistic Regression (71.29%), and SVM (71.09%). However, all models were affected by severe class imbalance: about 71% of trials in the full dataset received positive feedback, so these accuracies are only slightly above what always predicting positive feedback would achieve, and sensitivity for negative feedback (-1) was low. XGBoost, while the most balanced, had a sensitivity of 10.91%, a specificity of 94.48%, and an AUC of 0.6127, indicating only modest discriminative ability. McNemar’s test showed that the two types of error were significantly unequal, with negative-feedback trials far more often misclassified as positive than the reverse, which reflects the dataset’s class imbalance.
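To make the metrics reported above concrete, the sketch below computes them from a 2×2 confusion matrix with -1 (negative feedback) treated as the class of interest, which is the convention under which sensitivity measures detection of negative feedback. The counts are hypothetical, chosen only to resemble the pattern described above; they are not the study's results. `mcnemar.test` is base R's McNemar test, which compares the two off-diagonal (error) cells.

```r
# Hypothetical confusion matrix (rows = predicted, columns = actual); not the study's data
cm <- matrix(c(11,  17,
               89, 283),
             nrow = 2, byrow = TRUE,
             dimnames = list(predicted = c("-1", "1"), actual = c("-1", "1")))

# Treat -1 (negative feedback) as the class of interest
sensitivity <- cm["-1", "-1"] / sum(cm[, "-1"])  # share of actual -1 trials predicted as -1
specificity <- cm["1", "1"] / sum(cm[, "1"])     # share of actual 1 trials predicted as 1
accuracy    <- sum(diag(cm)) / sum(cm)

# McNemar's test on the error cells: 89 actual -1 predicted as 1 vs 17 the other way
mc <- mcnemar.test(cm)

print(round(c(sensitivity = sensitivity, specificity = specificity, accuracy = accuracy), 4))
print(mc)
```

In this toy matrix the accuracy looks respectable while sensitivity is very low; a significant McNemar result here reflects the imbalance between the two error cells rather than overall accuracy.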

Despite XGBoost emerging as the best-performing model, further refinements such as class balancing strategies, hyperparameter tuning, and deep learning approaches may improve classification of negative feedback and enhance model generalization. These findings underscore the challenges of classifying neural activity data and the importance of addressing imbalance in machine learning applications for neuroscience research.


Introduction

Recent neuroscience research has shown that decision-making and engagement emerge from neural activity distributed across many brain regions rather than from single localized areas. This finding, reported by Steinmetz et al. (2019) in “Distributed coding of choice, action, and engagement across the mouse brain,” underscores the complex relationship between neural signals and behavioral outcomes.

My report presents a predictive model that interprets behavioral states—including choice selection, engagement level, and movement initiation—from comprehensive neural recordings. I analyzed neural spike trains collected from diverse brain regions across multiple experimental sessions, offering insights into how different areas contribute to behavior and whether these contributions remain consistent.

I implemented a three-phase methodology: (1) exploratory data analysis to characterize the dataset and identify neural correlates of behavior; (2) data integration to extract shared neural patterns while accounting for session variability; and (3) predictive modeling to infer behavioral states from neural activity. This approach aims to determine how reliably distributed neural activity can predict an animal’s choices, engagement, and actions.


Exploratory Data Analysis

The exploratory analysis of the dataset, which consists of multi-session neuronal recordings from mice performing a visual discrimination task, reveals key insights into the structure and variability of neural activity. For every trial, the spike counts of each recorded neuron are stored in 40 time bins. Sessions differ considerably in size, with 619 to 1,769 recorded neurons in the first six sessions and 114 to 447 trials per session (5,081 trials in total). The analysis of spike activity per trial shows that while most trials exhibit low to moderate firing rates, certain trials show bursts of high activity, indicating potential task-related neuronal engagement. A session-wise comparison reveals notable heterogeneity in firing rates, with Session 8 showing the highest mean spike count (1.66 spikes/trial) and the greatest variability (standard deviation = 3.10), suggesting session-specific or stimulus-related differences. Moreover, at least 25% of trials exhibit zero spike activity, indicating periods of inactivity or non-engagement. A temporal analysis across trials further highlights fluctuations in neural responses, suggesting a possible relationship between stimulus conditions and neuronal firing.

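library(tidyverse)  # tibble, dplyr, tidyr and ggplot2 are used throughout
# Assumes `session` is a list holding the 18 session objects, loaded earlier
# Initialize the session_summary_data tibble with one placeholder row per session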
session_summary_data = tibble(
  mouse_id = rep('Mouse_Placeholder', 18),
  session_id = rep(0, 18),
  session_date = rep('YYYY-MM-DD', 18),
  total_brain_regions = rep(0, 18),
  total_neurons = rep(0, 18),
  total_trials = rep(0, 18),
  avg_success_rate = rep(0, 18)
)

for (i in 1:18) {
  current_session = session[[i]]
  session_summary_data[i, ] = tibble(
    mouse_id = current_session$mouse_name,
    session_id = i,
    session_date = current_session$date_exp,
    total_brain_regions = length(unique(current_session$brain_area)),
    total_neurons = dim(current_session$spks[[1]])[1],  # rows of a spks matrix = neurons
    total_trials = length(current_session$feedback_type),
    avg_success_rate = mean(current_session$feedback_type + 1) / 2  # maps -1/1 to 0/1
  )
}

head(session_summary_data)
## # A tibble: 6 × 7
##   mouse_id  session_id session_date total_brain_regions total_neurons
##   <chr>          <dbl> <chr>                      <dbl>         <dbl>
## 1 Cori               1 2016-12-14                     8           734
## 2 Cori               2 2016-12-17                     5          1070
## 3 Cori               3 2016-12-18                    11           619
## 4 Forssmann          4 2017-11-01                    11          1769
## 5 Forssmann          5 2017-11-02                    10          1077
## 6 Forssmann          6 2017-11-04                     5          1169
## # ℹ 2 more variables: total_trials <dbl>, avg_success_rate <dbl>

This code builds a summary table (session_summary_data) with one row per experimental session. It first initializes a placeholder tibble with seven columns, then loops over the 18 sessions and fills each row with values taken from that session's data: the mouse name, the experiment date, the number of distinct brain regions and of neurons recorded, the number of trials, and the session's success rate.

The resulting columns are the mouse identifier (mouse_id), the session number (session_id), the date of the experiment (session_date), the number of distinct brain regions in which neurons were recorded (total_brain_regions), the number of recorded neurons (total_neurons), and the number of trials (total_trials). The last column, avg_success_rate, is the proportion of trials in which the mouse received positive feedback, obtained by mapping feedback_type from -1/1 to 0/1 and averaging. This table is the basis for the per-mouse comparisons that follow.

Data Content and Context

The data for each session consist of 8 variables, summarized here for session 1:

##                Length Class  Mode     
## contrast_left  114    -none- numeric  
## contrast_right 114    -none- numeric  
## feedback_type  114    -none- numeric  
## mouse_name       1    -none- character
## brain_area     734    -none- character
## date_exp         1    -none- character
## spks           114    -none- list     
## time           114    -none- list

These 8 variables and their meanings are:

  • contrast_left is the contrast of the left stimulus
  • contrast_right is the contrast of the right stimulus
  • feedback_type is the feedback the mouse received on each trial: 1 for positive (correct) and -1 for negative (incorrect)
  • mouse_name is the mouse name (Cori, Forssmann, Hench, or Lederberg)
  • brain_area is the brain area in which each neuron was recorded (one entry per neuron)
  • date_exp is the date on which the experiment took place
  • spks is a list with one matrix per trial, holding the spike counts of each recorded neuron (rows) in each of 40 time bins (columns)
  • time is a list with one entry per trial, giving the centers of the time bins used in spks
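The structure described in this list can be illustrated with a small, entirely hypothetical session object (5 neurons, 3 trials, Poisson-distributed spike counts) built in base R. It shows why the summary code reads the neuron count from the rows of the first spks matrix and how the -1/1 feedback codes become a success rate. The field names match the list above, but every value is made up.

```r
set.seed(1)
n_neurons <- 5
n_trials  <- 3
n_bins    <- 40

# Hypothetical session with the same 8 fields as the real data
toy_session <- list(
  contrast_left  = c(0, 0.5, 1),
  contrast_right = c(0.25, 0, 1),
  feedback_type  = c(1, -1, 1),
  mouse_name     = "ToyMouse",
  brain_area     = c("CA1", "CA1", "LD", "VISp", "root"),  # one label per neuron
  date_exp       = "2000-01-01",
  # one neurons x time-bins matrix of spike counts per trial
  spks = lapply(seq_len(n_trials), function(i)
    matrix(rpois(n_neurons * n_bins, lambda = 0.3), nrow = n_neurons)),
  # one vector of (made-up) bin centers per trial
  time = lapply(seq_len(n_trials), function(i) seq(0.005, by = 0.01, length.out = n_bins))
)

n_recorded  <- nrow(toy_session$spks[[1]])               # neurons = rows
n_regions   <- length(unique(toy_session$brain_area))    # distinct recorded regions
success_pct <- mean((toy_session$feedback_type + 1) / 2) # -1/1 mapped to 0/1

print(dim(toy_session$spks[[1]]))
print(c(neurons = n_recorded, regions = n_regions, success = success_pct))
```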

Mean Activated Brain Regions Per Mouse

session_summary_data %>%
  group_by(mouse_id) %>%
  summarise(avg_brain_regions = mean(total_brain_regions)) %>%
  ggplot(aes(x = mouse_id, y = avg_brain_regions, fill = mouse_id)) +
  geom_bar(stat = 'identity') +
  labs(title = 'Mean Activated Brain Regions Per Mouse', x = "Mouse Name", y = "Average Brain Regions", fill = "Mouse ID")

This bar chart displays, for each mouse, the number of distinct brain regions recorded per session, averaged over that mouse's sessions. The x-axis shows the mouse names (Cori, Forssmann, Hench, and Lederberg), and the y-axis shows the average number of regions. A region is counted whenever at least one neuron was recorded in it, so this measure reflects where the recording probes were placed rather than how strongly those regions were activated.

Hench has the highest average number of recorded brain regions, meaning its sessions sampled the widest set of areas. Lederberg follows closely behind, while Forssmann and Cori have fewer regions on average. This variation most likely reflects differences in probe placement across sessions rather than differences in neural activity between the mice.

Mean Neurons Activated Per Mouse

session_summary_data %>%
  group_by(mouse_id) %>%
  summarise(avg_neurons = mean(total_neurons)) %>%
  ggplot(aes(x = mouse_id, y = avg_neurons, fill = mouse_id)) +
  geom_bar(stat = 'identity') +
  labs(title = 'Mean Neurons Activated Per Mouse', x = "Mouse Name", y = "Average Neurons", fill = "Mouse ID")

This bar chart shows the mean number of neurons recorded per session for each mouse. The x-axis displays the mouse names (Cori, Forssmann, Hench, and Lederberg), and the y-axis indicates the average number of recorded neurons (the number of rows in each session's spks matrices).

Forssmann has the highest average number of recorded neurons, followed by Hench, while Cori and Lederberg have fewer. This differs from the previous chart, where Hench had the most brain regions: Forssmann's sessions sampled fewer regions on average but more neurons per region, whereas Hench's sessions were spread over more regions with fewer neurons in each. Like the region counts, these numbers describe recording yield and do not by themselves say how active the neurons were.
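The neurons-per-region reading can be checked directly. The sketch below uses only the six rows printed by head(session_summary_data) earlier, so Forssmann's values cover sessions 4 to 6 only and the result is illustrative rather than a full per-mouse figure. It divides each session's neuron count by its region count and then averages within mouse using base R's aggregate.

```r
# The six session rows printed by head(session_summary_data) above
summary_head <- data.frame(
  mouse_id            = c("Cori", "Cori", "Cori", "Forssmann", "Forssmann", "Forssmann"),
  session_id          = 1:6,
  total_brain_regions = c(8, 5, 11, 11, 10, 5),
  total_neurons       = c(734, 1070, 619, 1769, 1077, 1169)
)

# Neurons recorded per distinct brain region, for each session
summary_head$neurons_per_region <- summary_head$total_neurons / summary_head$total_brain_regions

# Average each column within mouse
per_mouse <- aggregate(
  cbind(total_brain_regions, total_neurons, neurons_per_region) ~ mouse_id,
  data = summary_head,
  FUN = mean
)
print(per_mouse)
```

Adding the same ratio column to the full session_summary_data and running the same aggregate call would give the per-mouse figures behind both charts.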

Success Rate by Mouse

session_summary_data %>%
  group_by(mouse_id) %>%
  summarise(mean_success = mean(avg_success_rate)) %>%
  ggplot(aes(x = mouse_id, y = mean_success, fill = mouse_id)) +
  geom_bar(stat = 'identity') +
  labs(title = 'Average Success Rate Per Mouse', x = "Mouse Name", y = "Average Success Rate", fill = "Mouse Name")

This bar chart shows the average success rate per mouse across all recorded sessions. The x-axis represents the mouse names (Cori, Forssmann, Hench, and Lederberg), and the y-axis indicates the average success rate: the proportion of trials in which the mouse made a correct decision, computed per session and then averaged over that mouse's sessions with each session weighted equally.

From the chart, Lederberg has the highest average success rate, meaning this mouse performed the best in correctly responding to stimuli. Forssmann and Hench have similar success rates, slightly lower than Lederberg but still relatively high. Cori has the lowest success rate among the four, suggesting it may have struggled more in making correct decisions.

This pattern is worth noting because Forssmann had the most recorded neurons and Hench the most recorded brain regions, yet neither had the highest success rate. Since those counts reflect recording coverage rather than neural activation, the comparison shows only that recording more neurons or regions does not go hand in hand with better task performance. The differences in behavioral performance across mice may instead reflect individual learning capabilities, neural processing, or experimental conditions.
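Because each bar averages session-level success rates, every session counts equally regardless of how many trials it contains. The hypothetical two-session example below (made-up trial counts and rates) shows how this can differ from a pooled, trial-weighted rate.

```r
# Hypothetical sessions for one mouse: number of trials and success rate per session
n_trials     <- c(100, 400)
success_rate <- c(0.60, 0.80)

session_average <- mean(success_rate)                             # each session weighted equally
pooled_rate     <- sum(n_trials * success_rate) / sum(n_trials)   # each trial weighted equally

print(c(session_average = session_average, pooled = pooled_rate))
```

The two summaries agree only when sessions have similar trial counts, so per-mouse comparisons can shift slightly depending on which one is used.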

Neural Activity Analysis

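# Average total spike count per neuron, grouped by brain region, for one trial
spike_count_summary = function(trial_index, session_data) {
  spikes = session_data$spks[[trial_index]]  # neurons x 40 time bins
  brain_regions = session_data$brain_area    # region label for each neuron (row)
  spike_count = rowSums(spikes)              # total spikes per neuron over the trial
  avg_spike_per_region = tapply(spike_count, brain_regions, mean)
  return(avg_spike_per_region)
}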

selected_session = session[[9]]  # Choosing session 9 for analysis
spike_data = spike_count_summary(10, selected_session)

head(spike_data)
##      CA1      CA3       LD      LSr     ORBm       PL 
## 1.055556 1.593023 2.556962 1.636364 1.401639 1.660000

This code computes the average spike count per brain region for a single trial within a selected session. The function spike_count_summary takes a trial index and a session, sums each neuron's spikes across the 40 time bins of that trial, and then averages these per-neuron totals within each brain region.

The output shows the first six regions for trial 10 of session 9. Each column is a brain region (e.g., CA1, CA3, LD), and each value is the mean number of spikes per neuron in that region over the trial. Among the regions shown, LD has the highest value (about 2.56 spikes per neuron) and CA1 the lowest (about 1.06).

This analysis provides insights into how different brain regions respond during a specific trial, helping to identify which areas exhibit stronger neural activity. This could be useful for understanding how certain stimuli influence neural responses in different regions of the brain.
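A tiny hypothetical trial makes the two steps of spike_count_summary explicit: rowSums collapses each neuron's time bins (40 in the real data, 3 here) into one total, and tapply averages those totals within each brain region. The matrix and region labels are invented for illustration.

```r
# Hypothetical trial: 4 neurons (rows) x 3 time bins (columns)
spikes <- matrix(c(1, 0, 2,
                   0, 0, 1,
                   3, 1, 0,
                   0, 2, 2),
                 nrow = 4, byrow = TRUE)
brain_area <- c("CA1", "CA1", "LD", "LD")

spike_count <- rowSums(spikes)                        # per-neuron totals: 3 1 4 4
region_mean <- tapply(spike_count, brain_area, mean)  # CA1: (3 + 1) / 2, LD: (4 + 4) / 2

print(region_mean)
```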

# Per-region mean spike counts for every trial: one row per trial, one column per region
n_trials = length(selected_session$feedback_type)
region_means = t(sapply(1:n_trials, spike_count_summary, session_data = selected_session))
trial_spike_df = as_tibble(region_means) %>%
  mutate(trial_id = 1:n_trials)

trial_spike_df %>%
  pivot_longer(cols = -trial_id, names_to = 'Brain Region', values_to = 'Spike Count') %>%
  ggplot(aes(x = trial_id, y = `Spike Count`, color = `Brain Region`)) +
  geom_line() +
  geom_smooth(method = 'loess') +
  labs(title = 'Spike Activity Over Trials', x = "Trial ID")
## `geom_smooth()` using formula = 'y ~ x'

This plot visualizes spike activity over trials for different brain regions. The x-axis represents trial ID, which denotes the sequence of trials within a session, while the y-axis represents the spike count, showing the average number of spikes recorded in different brain regions for each trial. Each colored line corresponds to a different brain region, with a smoothed trend line (LOESS smoothing) added to highlight general patterns over trials.

Across trials, some brain regions show large fluctuations in spike activity, while others remain relatively stable throughout the session. The LOESS curves reveal distinct trends over time, with certain regions gradually increasing or decreasing in spike activity. Notably, the VPL and VISl regions exhibit consistently high spike counts compared to other regions, suggesting stronger or more persistent neural engagement in these areas. Some regions, such as root (neurons not assigned to a more specific area), decline in activity over trials, while others increase or fluctuate irregularly, pointing to differences in neural response dynamics across the session.

This visualization helps show how neural activity evolves over the course of an experiment and how the dynamics of individual brain regions differ across stages of the session.

# Each spks matrix has 40 columns (time bins), so there is one feature per bin
data_name = paste0("data", 1:40)

get_trial_data = function(session_id, trial_id) {
  spikes = session[[session_id]]$spks[[trial_id]]
  if (any(is.na(spikes))) {
    message("value missing in session ", session_id, ", trial ", trial_id)
  }

  # Average across neurons within each time bin: a 1 x 40 row of features
  trial_bin_average = matrix(colMeans(spikes), nrow = 1)
  colnames(trial_bin_average) = data_name
  trial_tibble = as_tibble(trial_bin_average) %>%
    add_column("trial_id" = trial_id) %>%
    add_column("contrast_left" = session[[session_id]]$contrast_left[trial_id]) %>%
    add_column("contrast_right" = session[[session_id]]$contrast_right[trial_id]) %>%
    add_column("feedback_type" = session[[session_id]]$feedback_type[trial_id])

  return(trial_tibble)
}

get_session_usable_data = function(session_id){
  n_trial = length(session[[session_id]]$spks)
  trial_list = list()
  for (trial_id in 1:n_trial){
    trial_tibble = get_trial_data(session_id,trial_id)
    trial_list[[trial_id]] = trial_tibble
  }
  session_tibble = as_tibble(do.call(rbind, trial_list))
  session_tibble = session_tibble %>%
    add_column("mouse_name" = session[[session_id]]$mouse_name) %>%
    add_column("date_exp" = session[[session_id]]$date_exp) %>%
    add_column("session_id" = session_id)
  return(session_tibble)
}

session_list = list()
for (session_id in 1:18) {
  session_list[[session_id]] = get_session_usable_data(session_id)
}

full_data_tibble = as_tibble(do.call(rbind, session_list))
full_data_tibble$session_id = as.factor(full_data_tibble$session_id)
# Absolute difference between the left and right stimulus contrasts
full_data_tibble$contrast_diff = abs(full_data_tibble$contrast_left - full_data_tibble$contrast_right)

# Binary success indicator (1 = positive feedback) for EDA plots
full_data_tibble$success = as.numeric(full_data_tibble$feedback_type == 1)

summary(full_data_tibble)
##      data1              data2              data3             data4         
##  Min.   :0.002566   Min.   :0.004667   Min.   :0.00000   Min.   :0.002334  
##  1st Qu.:0.019837   1st Qu.:0.019785   1st Qu.:0.01984   1st Qu.:0.019837  
##  Median :0.027304   Median :0.026927   Median :0.02730   Median :0.027426  
##  Mean   :0.029801   Mean   :0.029610   Mean   :0.02978   Mean   :0.029935  
##  3rd Qu.:0.036339   3rd Qu.:0.036301   3rd Qu.:0.03645   3rd Qu.:0.036339  
##  Max.   :0.143713   Max.   :0.095315   Max.   :0.13259   Max.   :0.100086  
##                                                                            
##      data5              data6             data7              data8         
##  Min.   :0.003501   Min.   :0.00177   Min.   :0.004219   Min.   :0.002566  
##  1st Qu.:0.020350   1st Qu.:0.02044   1st Qu.:0.021004   1st Qu.:0.021239  
##  Median :0.027658   Median :0.02826   Median :0.029101   Median :0.029712  
##  Mean   :0.030042   Mean   :0.03060   Mean   :0.031788   Mean   :0.032586  
##  3rd Qu.:0.036802   3rd Qu.:0.03769   3rd Qu.:0.039340   3rd Qu.:0.040622  
##  Max.   :0.104585   Max.   :0.12924   Max.   :0.126010   Max.   :0.132472  
##                                                                            
##      data9              data10             data11             data12        
##  Min.   :0.003501   Min.   :0.003501   Min.   :0.003422   Min.   :0.004219  
##  1st Qu.:0.021574   1st Qu.:0.021574   1st Qu.:0.022284   1st Qu.:0.022880  
##  Median :0.030457   Median :0.030641   Median :0.030717   Median :0.030822  
##  Mean   :0.033015   Mean   :0.033185   Mean   :0.033263   Mean   :0.033332  
##  3rd Qu.:0.041547   3rd Qu.:0.041723   3rd Qu.:0.041723   3rd Qu.:0.041121  
##  Max.   :0.103393   Max.   :0.142123   Max.   :0.099315   Max.   :0.103393  
##                                                                             
##      data13             data14             data15             data16        
##  Min.   :0.004277   Min.   :0.004277   Min.   :0.003968   Min.   :0.005133  
##  1st Qu.:0.022880   1st Qu.:0.022923   1st Qu.:0.023207   1st Qu.:0.023213  
##  Median :0.030822   Median :0.031115   Median :0.031570   Median :0.031979  
##  Mean   :0.033620   Mean   :0.033741   Mean   :0.034147   Mean   :0.034658  
##  3rd Qu.:0.041723   3rd Qu.:0.041878   3rd Qu.:0.042351   3rd Qu.:0.043174  
##  Max.   :0.136869   Max.   :0.119760   Max.   :0.136986   Max.   :0.111470  
##                                                                             
##      data17             data18             data19             data20        
##  Min.   :0.003501   Min.   :0.005088   Min.   :0.005137   Min.   :0.005988  
##  1st Qu.:0.023810   1st Qu.:0.023891   1st Qu.:0.023891   1st Qu.:0.024141  
##  Median :0.032301   Median :0.032423   Median :0.032995   Median :0.033069  
##  Mean   :0.035099   Mean   :0.035241   Mean   :0.035640   Mean   :0.035793  
##  3rd Qu.:0.044080   3rd Qu.:0.044369   3rd Qu.:0.044944   3rd Qu.:0.044860  
##  Max.   :0.109855   Max.   :0.116317   Max.   :0.118006   Max.   :0.096931  
##                                                                             
##      data21             data22             data23             data24        
##  Min.   :0.004219   Min.   :0.004219   Min.   :0.005133   Min.   :0.004667  
##  1st Qu.:0.024141   1st Qu.:0.024226   1st Qu.:0.024226   1st Qu.:0.024744  
##  Median :0.033645   Median :0.033628   Median :0.033755   Median :0.033645  
##  Mean   :0.035983   Mean   :0.036157   Mean   :0.036103   Mean   :0.036414  
##  3rd Qu.:0.045685   3rd Qu.:0.045845   3rd Qu.:0.045685   3rd Qu.:0.045845  
##  Max.   :0.101777   Max.   :0.101777   Max.   :0.107833   Max.   :0.106624  
##                                                                             
##      data25             data26             data27             data28        
##  Min.   :0.001711   Min.   :0.003957   Min.   :0.004219   Min.   :0.004219  
##  1st Qu.:0.024504   1st Qu.:0.024744   1st Qu.:0.024141   1st Qu.:0.024141  
##  Median :0.033708   Median :0.033755   Median :0.033645   Median :0.033069  
##  Mean   :0.036478   Mean   :0.036309   Mean   :0.036252   Mean   :0.036005  
##  3rd Qu.:0.046018   3rd Qu.:0.046018   3rd Qu.:0.046233   3rd Qu.:0.045760  
##  Max.   :0.099695   Max.   :0.109495   Max.   :0.092643   Max.   :0.096931  
##                                                                             
##      data29             data30             data31             data32        
##  Min.   :0.003422   Min.   :0.004219   Min.   :0.002334   Min.   :0.005133  
##  1st Qu.:0.024744   1st Qu.:0.024771   1st Qu.:0.024744   1st Qu.:0.024141  
##  Median :0.033647   Median :0.033426   Median :0.033628   Median :0.033276  
##  Mean   :0.036218   Mean   :0.036109   Mean   :0.035832   Mean   :0.035749  
##  3rd Qu.:0.045778   3rd Qu.:0.045508   3rd Qu.:0.044521   3rd Qu.:0.044416  
##  Max.   :0.114441   Max.   :0.095368   Max.   :0.097603   Max.   :0.106624  
##                                                                             
##      data33            data34            data35             data36        
##  Min.   :0.00211   Min.   :0.00531   Min.   :0.005988   Min.   :0.005291  
##  1st Qu.:0.02414   1st Qu.:0.02423   1st Qu.:0.024299   1st Qu.:0.024504  
##  Median :0.03335   Median :0.03271   Median :0.032995   Median :0.032787  
##  Mean   :0.03564   Mean   :0.03547   Mean   :0.035401   Mean   :0.035293  
##  3rd Qu.:0.04494   3rd Qu.:0.04442   3rd Qu.:0.044415   3rd Qu.:0.044413  
##  Max.   :0.11632   Max.   :0.09693   Max.   :0.101777   Max.   :0.092084  
##                                                                           
##      data37             data38             data39             data40        
##  Min.   :0.005988   Min.   :0.004667   Min.   :0.004219   Min.   :0.005988  
##  1st Qu.:0.024744   1st Qu.:0.024112   1st Qu.:0.024112   1st Qu.:0.023973  
##  Median :0.032844   Median :0.032710   Median :0.032498   Median :0.032423  
##  Mean   :0.035187   Mean   :0.034985   Mean   :0.035008   Mean   :0.035028  
##  3rd Qu.:0.044304   3rd Qu.:0.043640   3rd Qu.:0.043640   3rd Qu.:0.044341  
##  Max.   :0.096931   Max.   :0.153473   Max.   :0.092084   Max.   :0.092466  
##                                                                             
##     trial_id     contrast_left    contrast_right   feedback_type    
##  Min.   :  1.0   Min.   :0.0000   Min.   :0.0000   Min.   :-1.0000  
##  1st Qu.: 71.0   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:-1.0000  
##  Median :143.0   Median :0.2500   Median :0.2500   Median : 1.0000  
##  Mean   :151.6   Mean   :0.3419   Mean   :0.3241   Mean   : 0.4202  
##  3rd Qu.:218.0   3rd Qu.:0.5000   3rd Qu.:0.5000   3rd Qu.: 1.0000  
##  Max.   :447.0   Max.   :1.0000   Max.   :1.0000   Max.   : 1.0000  
##                                                                     
##   mouse_name          date_exp           session_id   contrast_diff   
##  Length:5081        Length:5081        10     : 447   Min.   :0.0000  
##  Class :character   Class :character   15     : 404   1st Qu.:0.0000  
##  Mode  :character   Mode  :character   9      : 372   Median :0.5000  
##                                        11     : 342   Mean   :0.4229  
##                                        12     : 340   3rd Qu.:0.7500  
##                                        13     : 300   Max.   :1.0000  
##                                        (Other):2876                   
##     success      
##  Min.   :0.0000  
##  1st Qu.:0.0000  
##  Median :1.0000  
##  Mean   :0.7101  
##  3rd Qu.:1.0000  
##  Max.   :1.0000  
## 
head(full_data_tibble)
## # A tibble: 6 × 49
##    data1  data2  data3  data4  data5  data6  data7  data8  data9 data10 data11
##    <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
## 1 0.0490 0.0368 0.0177 0.0150 0.0327 0.0286 0.0313 0.0123 0.0341 0.0191 0.0463
## 2 0.0300 0.0313 0.0341 0.0272 0.0259 0.0313 0.0218 0.0232 0.0232 0.0341 0.0272
## 3 0.0490 0.0504 0.0300 0.0436 0.0245 0.0409 0.0300 0.0381 0.0341 0.0422 0.0559
## 4 0.0559 0.0531 0.0272 0.0613 0.0572 0.0599 0.0450 0.0286 0.0395 0.0354 0.0368
## 5 0.0272 0.0436 0.0313 0.0245 0.0450 0.0381 0.0463 0.0572 0.0477 0.0163 0.0272
## 6 0.0490 0.0218 0.0163 0.0109 0.0123 0.0232 0.0272 0.0327 0.0163 0.0191 0.0300
## # ℹ 38 more variables: data12 <dbl>, data13 <dbl>, data14 <dbl>, data15 <dbl>,
## #   data16 <dbl>, data17 <dbl>, data18 <dbl>, data19 <dbl>, data20 <dbl>,
## #   data21 <dbl>, data22 <dbl>, data23 <dbl>, data24 <dbl>, data25 <dbl>,
## #   data26 <dbl>, data27 <dbl>, data28 <dbl>, data29 <dbl>, data30 <dbl>,
## #   data31 <dbl>, data32 <dbl>, data33 <dbl>, data34 <dbl>, data35 <dbl>,
## #   data36 <dbl>, data37 <dbl>, data38 <dbl>, data39 <dbl>, data40 <dbl>,
## #   trial_id <int>, contrast_left <dbl>, contrast_right <dbl>, …

This code turns the spike data of every trial and session into a structured dataset for further analysis. It defines two functions: get_trial_data, which takes the spike matrix of a given trial, averages the spike counts across neurons within each of the 40 time bins (producing the features data1 to data40), and attaches trial information such as the contrast levels and feedback type; and get_session_usable_data, which applies get_trial_data to every trial in a session and appends session-level metadata (mouse name, experiment date, and session ID). The script then iterates over the 18 sessions and combines the results into a single tibble (full_data_tibble) with one row per trial, 5,081 rows in total. It also adds the absolute contrast difference and a binary success indicator derived from feedback_type. The final dataset is suitable for statistical analysis and predictive modeling.
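The per-trial features are easy to misread, so here is a hypothetical 3-neuron, 4-bin trial (real trials have 40 bins). colMeans averages over neurons within each bin, as get_trial_data does, and summing those bin averages gives the trial's total spike count divided by the number of neurons. That sum is the quantity the later histogram and boxplot obtain by adding up data1 to data40 for each trial.

```r
# Hypothetical trial: 3 neurons (rows) x 4 time bins (columns)
spikes <- matrix(c(0, 2, 1, 1,
                   0, 1, 0, 1,
                   3, 3, 2, 1),
                 nrow = 3, byrow = TRUE)

# One feature per time bin: mean spike count across neurons
bin_means <- colMeans(spikes)
names(bin_means) <- paste0("data", seq_along(bin_means))

# Summing the bin features equals total spikes divided by the number of neurons
per_neuron_total <- sum(bin_means)
total_spikes     <- sum(spikes)

print(bin_means)
print(c(sum_of_features = per_neuron_total, total_spikes = total_spikes,
        total_over_neurons = total_spikes / nrow(spikes)))
```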

feedback_counts = full_data_tibble %>%
  count(feedback_type)

# Plot the distribution of feedback_type
ggplot(feedback_counts, aes(x = as.factor(feedback_type), y = n, fill = as.factor(feedback_type))) +
  geom_bar(stat = "identity", color = "black", alpha = 0.7) +
  scale_fill_manual(values = c("red2", "skyblue")) +
  labs(title = "Distribution of Feedback Type (Imbalance Visualization)",
       x = "Feedback Type",
       y = "Count",
       fill = "Feedback Type")

To check whether the planned response variable (feedback_type) is imbalanced, I plotted the distribution of its two classes. The x-axis shows the feedback type, where -1 (incorrect trials) and 1 (correct trials) are the two possible outcomes, and the y-axis shows the number of trials of each type. The blue bar corresponds to 1 (correct trials) and the red bar to -1 (incorrect trials). The large height difference between the bars shows that the dataset is imbalanced: according to the summary above, about 71% of the 5,081 trials (mean success of 0.7101) received positive feedback.

This imbalance affects modeling, because classifiers may favor the majority class (1) and produce biased predictions; a model that always predicts 1 would already be about 71% accurate. Techniques such as oversampling, undersampling, or class-weighted models may therefore be necessary to improve performance.
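As a rough sketch of what these remedies involve, the code below rebuilds the class counts implied by the summary above (5,081 trials with a mean success of 0.7101, i.e. 3,608 positive and 1,473 negative), computes the accuracy of always predicting 1, derives inverse-frequency class weights, and performs simple random oversampling of the minority class. The seed and the weighting formula are illustrative choices, not necessarily what the later models used. In practice, oversampling should be applied to the training split only, so that duplicated trials do not leak into the evaluation set.

```r
# Class counts implied by the summary above: 5,081 trials, mean success 0.7101
feedback_type <- c(rep(1, 3608), rep(-1, 1473))

# Accuracy of a trivial classifier that always predicts the majority class (1)
majority_baseline <- mean(feedback_type == 1)

# Inverse-frequency class weights: n / (number of classes * n_class)
class_counts  <- table(feedback_type)
class_weights <- length(feedback_type) / (length(class_counts) * class_counts)
obs_weights   <- as.numeric(class_weights[as.character(feedback_type)])

# Random oversampling: resample minority-class rows with replacement
set.seed(141)
minority_idx <- which(feedback_type == -1)
majority_idx <- which(feedback_type == 1)
extra_idx    <- sample(minority_idx, length(majority_idx) - length(minority_idx), replace = TRUE)
balanced_idx <- c(majority_idx, minority_idx, extra_idx)

print(round(majority_baseline, 4))
print(round(class_weights, 4))
print(table(feedback_type[balanced_idx]))
```

The per-trial weights could then be passed to a learner that accepts observation weights, for example through the `weight` argument of `xgboost::xgb.DMatrix`.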

# x: sum of the 40 bin-averaged features, i.e. the trial's spike count per neuron
ggplot(full_data_tibble, aes(x = rowSums(select(full_data_tibble, starts_with("data"))),
                             fill = as.factor(session_id))) +
  geom_histogram(bins = 100, alpha = 0.6, position = "identity") +
  labs(title = "Distribution of Neural Spike Counts Per Trial Across Sessions",
       x = "Spikes per Neuron per Trial",
       y = "Frequency",
       fill = "Session")

This histogram shows the distribution of neural spike counts per trial across sessions. The x-axis is the sum of a trial's 40 bin-averaged features (data1 to data40), which equals the trial's total spike count divided by the number of recorded neurons, i.e. the average number of spikes per neuron in that trial. The y-axis is the frequency, the number of trials with a given value. Each session is color-coded, as indicated by the legend on the right, so spike distributions can be compared across sessions.

From the visualization, we can see that different sessions exhibit varying distributions of spike counts. Some sessions, particularly those represented by colors concentrated on the left side of the graph (e.g., green), have a higher frequency of low spike counts, while other sessions, such as those represented in blue and purple, have a broader spread of spike counts, extending further along the x-axis. This suggests variability in neural activity between sessions, possibly influenced by experimental conditions, stimulus variations, or differences in mouse behavior.

The heavy overlap between the histograms indicates that most sessions share a common range of spike counts, while some sessions deviate with distinct distributions. This kind of analysis is useful for identifying trends in neural responsiveness and for spotting sessions with markedly different firing patterns.

# y: sum of the 40 bin-averaged features, i.e. the trial's spike count per neuron
ggplot(full_data_tibble, aes(x = as.factor(session_id), y = rowSums(select(full_data_tibble, starts_with("data"))), fill = as.factor(session_id))) +
  geom_boxplot(outlier.color = "red") +
  labs(title = "Boxplot of Neural Spike Counts Across Sessions",
       x = "Session",
       y = "Spikes per Neuron per Trial",
       fill = "Session")

This boxplot shows the distribution of per-trial spike counts (the average number of spikes per neuron in each trial) across sessions. The box spans the interquartile range (IQR), the middle 50% of values for each session, and the line inside the box marks the median. The whiskers extend to the most extreme values lying within 1.5 times the IQR of the box edges, and points beyond that range are drawn as red dots: outliers, i.e. trials whose spike counts deviate strongly from the session's typical distribution.

From the visualization, there is noticeable variation in spike activity between sessions. For example, session 12 has a higher median spike count and a wider distribution, while session 6 exhibits a much lower median spike count with a narrow spread. Some sessions, such as sessions 12 and 7, have multiple outliers, suggesting that certain trials in these sessions had exceptionally high or low spike counts. Additionally, sessions with larger IQRs, such as sessions 7 and 12, show greater variability in neural activity across trials, whereas sessions with smaller IQRs, such as sessions 6 and 14, have more consistent neural firing patterns. This analysis helps in identifying session-to-session differences, potential anomalies, and overall trends in neural spike activity.
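The outlier rule can be reproduced in a few lines. The sketch below applies it to a made-up vector using stats::quantile with its default (type 7) quartiles and a 1.5 × IQR fence, which, to my understanding, matches the defaults of ggplot2's geom_boxplot. Points outside the fences are flagged as outliers, and the whiskers stop at the most extreme remaining points.

```r
# Hypothetical per-trial values for one session
y <- c(1, 2, 3, 4, 5, 6, 7, 8, 9, 30)

q     <- quantile(y, c(0.25, 0.75))  # type-7 quartiles: 3.25 and 7.75
iqr   <- q[[2]] - q[[1]]
lower <- q[[1]] - 1.5 * iqr          # lower fence
upper <- q[[2]] + 1.5 * iqr          # upper fence

outliers <- y[y < lower | y > upper]
whiskers <- range(y[y >= lower & y <= upper])  # whisker ends

print(c(lower_fence = lower, upper_fence = upper))
print(outliers)
print(whiskers)
```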

pca_result = prcomp(full_data_tibble[, 1:40], center = TRUE, scale = TRUE)
pca_df = as_tibble(pca_result$x)
pca_df$session_id = full_data_tibble$session_id
pca_df$mouse_name = full_data_tibble$mouse_name

pca_df %>%
  ggplot(aes(x = PC1, y = PC2, color = as.factor(session_id))) +
  geom_point() +
  labs(title = 'PCA: PC1 vs PC2 by Session', col = "Session ID")

I decided to perform a Principal Component Analysis (PCA) to reduce the dimensionality of the neural data while preserving as much variance as possible. Given that the dataset contains a large number of neural activity features, PCA helps to transform the data into a lower-dimensional space, making it easier to visualize and analyze patterns. This method allows us to explore whether trials from different sessions exhibit distinct clusters or if there is significant overlap between them. Additionally, PCA helps to identify underlying trends in the data that might not be immediately apparent in the high-dimensional space.

The scatter plot shows the first two principal components (PC1 and PC2) of the dataset. Each point corresponds to a trial, and colors indicate session IDs. There is some noticeable clustering, but the substantial overlap between sessions indicates that session-level differences may not be the dominant source of variation in the dataset. Points are concentrated toward the right along PC1, while PC2 introduces some additional separation. The spread of points alone does not show how much variance each component explains; that information comes from the variance proportions reported by summary(pca_result). Some sessions, such as those shown in blue and brown, appear more dispersed, possibly indicating greater variability in neural activity within those sessions. Overall, this PCA visualization provides insight into the structure of the neural data and suggests that further analysis, such as clustering or additional feature engineering, may be needed to extract clearer patterns.
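How much variance PC1 and PC2 actually capture is not visible in the scatter plot itself. That information comes from the component variances stored in the prcomp object. The sketch below computes these proportions for a simulated 40-feature matrix that only stands in for the spike data. The helper variance_explained() is hypothetical. For the real data, the same numbers would come from variance_explained(pca_result) or summary(pca_result).

```r
# Hypothetical helper: proportion of variance explained by each principal component
variance_explained = function(pca) {
  eigenvalues = pca$sdev^2          # variance captured by each component
  eigenvalues / sum(eigenvalues)
}

# Simulated stand-in for the 40 spike features (not the real data)
set.seed(1)
sim = matrix(rpois(200 * 40, lambda = 3), nrow = 200)
sim[, 2] = sim[, 1] + rnorm(200)    # one correlated pair so the leading PC stands out
pca_sim = prcomp(sim, center = TRUE, scale. = TRUE)

pve = variance_explained(pca_sim)
round(pve[1:2], 3)         # share of variance in PC1 and PC2
round(sum(pve[1:2]), 3)    # cumulative share represented in a PC1 vs PC2 plot
```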

pca_df %>%
  ggplot(aes(x = PC1, y = PC2, color = as.factor(mouse_name))) +
  geom_point() +
  labs(title = 'PCA: PC1 vs PC2 by Mouse Name', col = "Mouse Name")

This scatter plot represents the PCA projection of the neural data, with each point corresponding to a trial and colored based on the mouse name instead of session id as in the previous plot. The x-axis represents the first principal component (PC1), and the y-axis represents the second principal component (PC2). By reducing the high-dimensional neural data into two principal components, this plot helps visualize patterns and potential clustering based on different mice.

The visualization shows substantial overlap between the different mice, which suggests that neural activity, as captured by the first two principal components, does not strongly separate by mouse identity. However, some mice are concentrated in specific regions of the plot. Lederberg (purple) dominates the right side, while Cori (red) and Forssmann (green) are more dispersed toward the left. Hench (blue-green) has a wider spread but overlaps with the other mice.

The clustering patterns suggest that while there may be some mouse-specific differences in neural activity, the overall variance in the dataset is not primarily explained by mouse identity. This indicates that other factors, such as trial conditions, session variability, or task performance, may play a more significant role in distinguishing neural patterns. Further analysis, such as incorporating additional features or clustering techniques, might be necessary to uncover stronger mouse-specific trends.

library(Rtsne)
set.seed(123)
tsne_result = Rtsne(full_data_tibble[, 1:40])
tsne_df = as_tibble(tsne_result$Y)
tsne_df$session_id = full_data_tibble$session_id
tsne_df$mouse_name = full_data_tibble$mouse_name

I then decided to use t-Distributed Stochastic Neighbor Embedding (t-SNE) to better capture the nonlinear structure in the neural data, as PCA is limited to linear transformations and may not fully separate complex patterns. Since neural spike data is high-dimensional and likely contains intricate relationships, t-SNE helps visualize clusters by preserving local similarities between data points. Unlike PCA, which primarily maximizes variance, t-SNE is particularly useful for detecting subtle groupings that may correspond to different mice, sessions, or behavioral responses. This approach provides a more intuitive and interpretable low-dimensional representation of the neural activity.

# Visualizing clusters
tsne_df %>%
  ggplot(aes(x = V1, y = V2, color = as.factor(mouse_name))) +
  geom_point() +
  labs(title = 't-SNE Representation of Neural Data', col = "Mouse Name")

This scatter plot represents the t-SNE of the neural data, where each point corresponds to a trial and is colored according to the mouse name. The x-axis (V1) and y-axis (V2) represent the two components generated by t-SNE, which map high-dimensional neural activity data into a two-dimensional space while preserving local relationships.

Unlike PCA, which captures global variance, t-SNE focuses on preserving local neighborhoods within the data. As a result, the sizes of the clusters and the distances between them are not directly interpretable. The plot shows some separation, particularly for Lederberg (purple) in the upper region and for Hench (blue-green), which forms a distinct cluster on the left. Even so, there is still substantial overlap between mice. This suggests that while some mouse-specific patterns exist in neural activity, they are not completely separable in the reduced space. The dense clustering indicates that neural activity patterns are similar across mice and may be shaped more by experimental conditions than by individual identity. This visualization helps assess whether mouse identity significantly affects neural firing patterns or whether other factors, such as stimulus conditions or task performance, play a larger role in shaping neural responses.

set.seed(123)
kmeans_cluster = kmeans(tsne_result$Y, centers = 4)
tsne_df$cluster = as.factor(kmeans_cluster$cluster)
tsne_df %>%
  ggplot(aes(x = V1, y = V2, color = cluster)) +
  geom_point() +
  labs(title = 'K-means Clustering on t-SNE Reduced Data', col = "Cluster")

This plot represents the results of K-means clustering performed on the t-SNE reduced neural data. The x-axis (V1) and y-axis (V2) represent the two dimensions generated by t-SNE, which was used to project the high-dimensional neural activity data into a more interpretable space. Each point corresponds to a trial, and its color indicates the cluster assignment determined by K-means with four clusters.

The K-means algorithm partitioned the data into four clusters, grouping trials with similar positions in the t-SNE embedding. The clusters form layered, horizontal bands rather than compact, well-separated groups. This pattern suggests that K-means is dividing a largely continuous cloud of points rather than recovering discrete structure, and the lack of clear gaps between neighboring clusters points the same way. Because K-means was run on the two t-SNE coordinates rather than on the original 40 features, the clusters describe the embedding, which can distort global distances.

This clustering analysis helps assess whether distinct neural response patterns exist in the dataset. If these clusters correspond to meaningful differences, such as different behavioral responses, experimental conditions, or mouse identities, it could suggest underlying structure in the neural activity. However, the relatively gradual transition between clusters suggests that additional tuning of hyperparameters (e.g., the number of clusters) or alternative clustering methods may be necessary to extract more well-defined groups.
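One common way to tune the number of clusters is the elbow method. It runs K-means over a range of k and looks for the point where the total within-cluster sum of squares stops dropping sharply. The sketch below assumes a two-column coordinate matrix like tsne_result$Y. The helper elbow_wss() is hypothetical, and the three simulated groups are toy data, not the t-SNE embedding.

```r
# Hypothetical helper: total within-cluster sum of squares for k = 1..k_max
elbow_wss = function(coords, k_max = 10, nstart = 25) {
  sapply(seq_len(k_max), function(k) {
    # nstart > 1 reruns K-means from several random starts and keeps the best fit
    kmeans(coords, centers = k, nstart = nstart)$tot.withinss
  })
}

# Simulated 2-D embedding with three well-separated groups (stand-in for tsne_result$Y)
set.seed(123)
coords = rbind(
  matrix(rnorm(100, mean = 0), ncol = 2),
  matrix(rnorm(100, mean = 6), ncol = 2),
  matrix(rnorm(100, mean = -6), ncol = 2)
)
wss = elbow_wss(coords, k_max = 8)
plot(seq_along(wss), wss, type = "b",
     xlab = "Number of clusters (k)", ylab = "Total within-cluster SS")
```

On the document's data, the equivalent call would be elbow_wss(tsne_result$Y). For the simulated data, the curve should bend at k = 3. When the elbow is not clear, silhouette analysis is a common alternative.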

After analyzing the data, we can now move forward with modeling. Before starting this phase, we need to format our data appropriately.


Data Integration

set.seed(123)  # For reproducibility

# Selecting relevant features
# predictor_columns = c(paste0("data", 1:40), "contrast_left", "contrast_right", "contrast_diff")

# sample = sample(c(TRUE, FALSE), nrow(full_data_tibble), replace=TRUE, prob=c(0.8,0.2))
# train_data = full_data_tibble[sample, ]
# test_data = full_data_tibble[!sample, ]

train_indices = sample(1:nrow(full_data_tibble), size = 0.8 * nrow(full_data_tibble), replace = FALSE)
train_data = full_data_tibble[train_indices, ]
test_data = full_data_tibble[-train_indices, ]

# remove the non-predictors
X_train = train_data %>%
  select(-c("trial_id", "feedback_type", "mouse_name", "date_exp", "session_id", "contrast_diff", "success"))

Y_train = train_data %>%
  select("feedback_type") %>%
  pull() # Modify to a vector

X_test = test_data %>%
  select(-c("trial_id", "feedback_type", "mouse_name", "date_exp", "session_id", "contrast_diff", "success"))

Y_test = test_data %>%
  select("feedback_type") %>%
  pull() # Modify to a vector


# Standardize numeric features (only for SVM, KNN, Logistic Regression)
preprocess_params = preProcess(X_train, method = c("center", "scale"))  # Compute means & std devs

# Apply standardization
X_train_scaled = predict(preprocess_params, X_train)
X_test_scaled = predict(preprocess_params, X_test)

# Convert response variable to a factor for classification models
Y_train_factor = as.factor(Y_train)
Y_test_factor = as.factor(Y_test)

# Define 5-Fold Cross-Validation
# cv_control = trainControl(method = "cv", number = 5, savePredictions = TRUE)

This code prepares the dataset for predictive modeling in three steps: it splits the data into training and testing sets, selects the predictor columns, and standardizes the numeric variables. First, an 80-20 random split divides the trials into train_data and test_data, so that models are trained and evaluated on separate subsets. Next, several columns are removed from X_train and X_test: the identifier and metadata columns (trial_id, mouse_name, date_exp, and session_id), the outcome columns feedback_type and success, and contrast_diff, which is derived from the two contrast columns. This leaves the 40 neural spike features (data1 to data40) plus contrast_left and contrast_right. The response variable, feedback_type, is extracted separately as Y_train and Y_test in vector form for use in the classification models.

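Additionally, standardized versions of X_train and X_test are created. KNN and SVM rely on distances between feature vectors, so features on larger scales would otherwise dominate. Logistic regression does not strictly require standardization, but standardizing puts its coefficients on a comparable per-standard-deviation scale. The centering and scaling parameters are estimated on the training set only and then applied to the test set, which avoids leaking test-set information into training.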
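The train-only standardization described above can be sketched in base R. The helpers fit_standardizer() and apply_standardizer() are hypothetical and implement only the center-and-scale step, roughly what preProcess(method = c("center", "scale")) followed by predict() does for numeric columns. The matrices are toy data.

```r
# Hypothetical helpers: estimate means and SDs on the training set only,
# then apply those same parameters to any data set
fit_standardizer = function(train) {
  list(means = colMeans(train), sds = apply(train, 2, sd))
}
apply_standardizer = function(params, data) {
  centered = sweep(as.matrix(data), 2, params$means, "-")
  sweep(centered, 2, params$sds, "/")
}

set.seed(123)
train = matrix(rnorm(80 * 3, mean = 5, sd = 2), ncol = 3)  # toy training features
test = matrix(rnorm(20 * 3, mean = 5, sd = 2), ncol = 3)   # toy test features

params = fit_standardizer(train)
train_scaled = apply_standardizer(params, train)
test_scaled = apply_standardizer(params, test)  # uses training means/SDs only

round(colMeans(train_scaled), 10)  # 0 for every column
round(colMeans(test_scaled), 2)    # close to, but not exactly, 0
```

The test columns are not exactly centered, and that is intended. Re-estimating the parameters on the test set would let information about the test data influence the model inputs.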


Predictive Modeling

Logistic Regression Model

Training Log Reg model

# Modified X_train basically but it includes a feedback_type column
# train_log = train_data %>%
#   select(-c("trial_id", "mouse_name", "date_exp", "session_id", "contrast_diff", "success"))

# To convert to binary for log reg
# train_log$feedback_type[train_log$feedback_type < 0] = 0
Y_train_log = as.numeric(ifelse(Y_train == -1, 0, 1))

# Fit the logistic regression model
log_model = glm(Y_train_log ~ ., data = X_train_scaled,
                 family = binomial
                 )

# log_model_cv = train(
#   Y_train_log ~ ., 
#   data = X_train_scaled, 
#   method = "glm", 
#   family = binomial, 
#   trControl = cv_control
# )

# For a combined data frame
# log_model = glm(feedback_type ~ ., data = train_log, 
#                  family = binomial
#                  )

# Summarize the model
summary(log_model)
## 
## Call:
## glm(formula = Y_train_log ~ ., family = binomial, data = X_train_scaled)
## 
## Coefficients:
##                 Estimate Std. Error z value Pr(>|z|)    
## (Intercept)     0.980880   0.037230  26.347  < 2e-16 ***
## data1          -0.109541   0.065531  -1.672  0.09461 .  
## data2          -0.031326   0.070375  -0.445  0.65623    
## data3          -0.111928   0.069267  -1.616  0.10612    
## data4           0.062072   0.070257   0.883  0.37697    
## data5          -0.178803   0.069906  -2.558  0.01053 *  
## data6          -0.022144   0.071484  -0.310  0.75673    
## data7          -0.021730   0.073261  -0.297  0.76677    
## data8          -0.016081   0.073639  -0.218  0.82713    
## data9          -0.116137   0.074096  -1.567  0.11703    
## data10         -0.015121   0.071761  -0.211  0.83311    
## data11         -0.236578   0.073656  -3.212  0.00132 ** 
## data12         -0.050799   0.074510  -0.682  0.49538    
## data13          0.037077   0.073212   0.506  0.61255    
## data14         -0.074562   0.072808  -1.024  0.30580    
## data15          0.048726   0.074704   0.652  0.51423    
## data16         -0.015645   0.075468  -0.207  0.83577    
## data17          0.113729   0.075830   1.500  0.13367    
## data18          0.023873   0.077465   0.308  0.75795    
## data19          0.192517   0.078605   2.449  0.01432 *  
## data20         -0.011125   0.075702  -0.147  0.88317    
## data21          0.057043   0.076545   0.745  0.45614    
## data22          0.042744   0.076764   0.557  0.57765    
## data23          0.043544   0.077376   0.563  0.57360    
## data24          0.085365   0.076652   1.114  0.26543    
## data25          0.036850   0.078732   0.468  0.63976    
## data26          0.151600   0.077035   1.968  0.04907 *  
## data27         -0.059277   0.078791  -0.752  0.45185    
## data28          0.056003   0.077633   0.721  0.47068    
## data29          0.010457   0.077545   0.135  0.89273    
## data30          0.063793   0.077444   0.824  0.41009    
## data31          0.032187   0.077480   0.415  0.67783    
## data32          0.083059   0.077201   1.076  0.28198    
## data33          0.191368   0.077434   2.471  0.01346 *  
## data34         -0.067479   0.076399  -0.883  0.37710    
## data35         -0.004042   0.074952  -0.054  0.95699    
## data36         -0.013496   0.075114  -0.180  0.85741    
## data37         -0.061366   0.073823  -0.831  0.40583    
## data38          0.296969   0.074706   3.975 7.03e-05 ***
## data39         -0.018015   0.073213  -0.246  0.80563    
## data40         -0.088062   0.071900  -1.225  0.22066    
## contrast_left  -0.003776   0.036799  -0.103  0.91826    
## contrast_right -0.076130   0.038278  -1.989  0.04671 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 4889.7  on 4063  degrees of freedom
## Residual deviance: 4627.7  on 4021  degrees of freedom
## AIC: 4713.7
## 
## Number of Fisher Scoring iterations: 4
# Use model to make predictions
log_probs = predict(log_model, X_test_scaled, type = "response")
# log_probs_cv = predict(log_model_cv, X_test, type = "response")

# Convert probabilities to class labels (1 if prob > 0.5, else -1)
# log_pred = ifelse(log_probs > 0.5, 1, -1)

# (1 if prob > 0.5, else 0)
log_pred = ifelse(log_probs > 0.5, 1, 0)
# log_pred_cv = ifelse(log_probs_cv > 0.5, 1, 0)

# Convert predictions and actual labels to factors
log_pred = as.factor(log_pred)
# log_pred_cv = as.factor(log_pred_cv)
# log_pred = as.factor(ifelse(log_pred == -1, 0, 1))

Y_test_log = as.numeric(ifelse(Y_test == -1, 0, 1))
Y_test_log = as.factor(Y_test_log)

This code fits a logistic regression model to predict the feedback type of a trial from neural spike activity and contrast features. First, it recodes the response variable (Y_train) from -1/1 to 0/1, as glm() with a binomial family expects. The model is then trained on X_train_scaled, the standardized version of X_train, which contains only predictive features after the categorical and non-relevant columns were removed. The model's coefficients and their statistical significance are examined with summary(log_model). Finally, the trained model predicts probabilities for the standardized test set (X_test_scaled), and these are converted into binary class labels (1 if the probability exceeds 0.5, otherwise 0). The predictions and actual test labels are converted to factors for evaluation.

The model output describes the relationship between the predictor variables and the response variable (feedback_type). Because the features are standardized, each coefficient is the change in the log-odds of positive feedback (1) for a one-standard-deviation increase in that feature. Positive values increase this likelihood and negative values decrease it. Smaller p-values indicate stronger evidence of an association with the response. Several features have p-values below 0.05: data5 (p = 0.01053), data11 (p = 0.00132), data19 (p = 0.01432), data26 (p = 0.04907), data33 (p = 0.01346), data38 (p < 0.0001), and contrast_right (p = 0.04671). With 42 predictors tested at the 0.05 level, roughly two would be expected to fall below that threshold by chance alone even if none had a real effect, so the weaker of these associations should be interpreted cautiously. Many other features have high p-values, indicating weak associations. The AIC of 4713.7 is only meaningful when compared with other models fit to the same data. The drop in deviance from 4889.7 (null) to 4627.7 (residual) on 42 degrees of freedom shows that the predictors add information, although the improvement is modest. Feature selection or regularization could further refine the model.
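The deviance comparison above can be made explicit with a likelihood-ratio test and McFadden's pseudo-R^2. The sketch below plugs in the deviances printed by summary(log_model). With the fitted object itself, the same quantities are available as log_model$null.deviance, log_model$deviance, log_model$df.null and log_model$df.residual.

```r
# Deviance values copied from summary(log_model) above
null_deviance = 4889.7; null_df = 4063
resid_deviance = 4627.7; resid_df = 4021

# Likelihood-ratio (deviance) test of the full model against the intercept-only model
lr_stat = null_deviance - resid_deviance   # about 262
lr_df = null_df - resid_df                 # 42 predictors
lr_p = pchisq(lr_stat, df = lr_df, lower.tail = FALSE)

# McFadden's pseudo-R^2: for 0/1 outcomes, the share of the null deviance removed
pseudo_r2 = 1 - resid_deviance / null_deviance

c(lr_stat = lr_stat, lr_df = lr_df, lr_p = lr_p, pseudo_r2 = round(pseudo_r2, 3))
```

The test is highly significant. However, the pseudo-R^2 of about 0.054 means the predictors remove only around 5% of the null deviance, which is consistent with the weak classification results that follow.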

Confusion Matrix

# Compute confusion matrix and accuracy
conf_matrix = confusionMatrix(log_pred, Y_test_log)
accuracy = conf_matrix$overall["Accuracy"]

# Print results
print(conf_matrix)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0  15  10
##          1 282 710
##                                          
##                Accuracy : 0.7129         
##                  95% CI : (0.684, 0.7405)
##     No Information Rate : 0.708          
##     P-Value [Acc > NIR] : 0.3798         
##                                          
##                   Kappa : 0.0501         
##                                          
##  Mcnemar's Test P-Value : <2e-16         
##                                          
##             Sensitivity : 0.05051        
##             Specificity : 0.98611        
##          Pos Pred Value : 0.60000        
##          Neg Pred Value : 0.71573        
##              Prevalence : 0.29204        
##          Detection Rate : 0.01475        
##    Detection Prevalence : 0.02458        
##       Balanced Accuracy : 0.51831        
##                                          
##        'Positive' Class : 0              
## 
print(paste("Logistic Regression Accuracy:", accuracy))
## [1] "Logistic Regression Accuracy: 0.712881022615536"
# Convert confusion matrix to a data frame
conf_matrix_df = as.data.frame.table(conf_matrix$table)

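# Rename columns for clarity; as.data.frame.table() lists the table's dimensions
# in order, so the first column is Prediction (rows) and the second is Reference
colnames(conf_matrix_df) = c("Predicted_Class", "Actual_Class", "Frequency")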

# Convert frequency to numeric
conf_matrix_df$Frequency = as.numeric(conf_matrix_df$Frequency)

# Plot the confusion matrix
ggplot(data = conf_matrix_df, aes(x = Predicted_Class, y = Actual_Class, fill = Frequency)) +
  geom_tile(color = "white") +  # Creates the heatmap
  geom_text(aes(label = Frequency), vjust = 0.5, size = 5) +  # Adds text labels for each cell
  scale_fill_gradient(low = "white", high = "red3") +  # Color gradient for intensity
  labs(
    x = "Predicted Class",
    y = "Actual Class",
    title = "Confusion Matrix for Logistic Regression Model"
  )

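The output shows an overall accuracy of 71.29%, meaning the model correctly classifies about 71% of test trials. The confusion matrix contains 710 correctly classified positive-feedback trials and 15 correctly classified negative-feedback trials. By contrast, 282 negative-feedback trials are misclassified as positive, and only 10 positive-feedback trials are misclassified as negative. This highlights the model's tendency to label negative-feedback trials as positive.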

Note that caret treats the first factor level (0, negative feedback) as the "positive" class, so sensitivity here is the proportion of negative-feedback trials correctly identified. That sensitivity is only 0.05051, indicating that the model rarely recognizes negative feedback. Specificity, the proportion of positive-feedback trials correctly identified, is very high (0.98611). This is mainly because the model predicts positive feedback for almost every trial (992 of 1017). In other words, the model is heavily skewed toward predicting positive feedback (feedback type = 1). The McNemar's test p-value (<2e-16) confirms that the two types of error are far from symmetric (282 versus 10).

With a balanced accuracy of 0.51831, the model performs only slightly better than random guessing when differentiating between feedback types. The kappa score of 0.0501, which measures agreement between predictions and actual labels beyond chance, is also low. Moreover, the accuracy of 71.29% is not significantly higher than the no-information rate of 70.8% (p = 0.3798), which is the accuracy obtained by always predicting positive feedback. The large class imbalance in the dataset, described earlier in the report, is the most likely cause of these issues.
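The statistics discussed above follow directly from the four cell counts. The sketch below recomputes them in base R from the logistic regression confusion matrix. As in caret's output, the first level (0, negative feedback) is the "positive" class. The helper confusion_stats() is hypothetical, and its McNemar p-value uses mcnemar.test() with its default continuity correction.

```r
# Hypothetical helper: headline statistics from a 2x2 table with
# rows = Prediction, columns = Reference, first level = "positive" class
confusion_stats = function(tab) {
  n = sum(tab)
  sensitivity = tab[1, 1] / sum(tab[, 1])   # first-level trials correctly predicted
  specificity = tab[2, 2] / sum(tab[, 2])   # second-level trials correctly predicted
  accuracy = sum(diag(tab)) / n
  expected_agreement = sum(rowSums(tab) * colSums(tab)) / n^2  # chance agreement for kappa
  c(
    accuracy = accuracy,
    no_information_rate = max(colSums(tab)) / n,   # always predict the majority class
    sensitivity = sensitivity,
    specificity = specificity,
    balanced_accuracy = (sensitivity + specificity) / 2,
    kappa = (accuracy - expected_agreement) / (1 - expected_agreement),
    mcnemar_p = mcnemar.test(tab)$p.value          # compares the two off-diagonal cells
  )
}

# Logistic regression counts from the confusion matrix above (0 = negative feedback)
log_tab = matrix(c(15, 282, 10, 710), nrow = 2,
                 dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))
s = confusion_stats(log_tab)
round(s, 4)
```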

After evaluating the logistic regression model, I tested a K-Nearest Neighbors (KNN) model to determine whether a non-parametric approach could improve classification accuracy.

K-Nearest Neighbors (KNN)

Find best K-value for KNN

# Function to compute accuracy for different values of K
calculate_knn_accuracy = function(k_value) {
  knn_pred = knn(train = X_train_scaled, test = X_test_scaled, cl = Y_train, k = k_value)
  return(mean(knn_pred == Y_test))  # Compute Accuracy
}

# Test K values from 1 to 250
k_values = 1:250
accuracies = sapply(k_values, calculate_knn_accuracy)

# Find Best K (max accuracy)
best_k = k_values[which.max(accuracies)]
print(paste("Best K:", best_k))
## [1] "Best K: 54"
# Plot Accuracy vs. K Value
accuracy_df = tibble(K = k_values, Accuracy = accuracies)

ggplot(accuracy_df, aes(x = K, y = Accuracy)) +
  geom_line(color = "skyblue") +
  geom_point(color = "red2") +
  labs(title = "KNN Accuracy vs. K Value",
       x = "Number of Neighbors (K)",
       y = "Accuracy")

Logistic regression assumes a linear relationship between the features and the log-odds of the response. KNN, by contrast, is a non-parametric, distance-based algorithm that classifies each data point by the majority class among its K nearest neighbors. Its effectiveness depends on the choice of K, the number of neighbors considered, so several values must be tested to find a good one.

The code evaluates KNN performance for values of K from 1 to 250. For each K, the calculate_knn_accuracy function fits KNN on the training data (X_train_scaled) and computes its accuracy on the standardized test data (X_test_scaled). The accuracies are stored in the accuracies vector, and the best K is the one that maximizes accuracy; the output shows that the optimal K is 54. Because K is chosen by its accuracy on the test set, the accuracy at K = 54 is an optimistic estimate of performance on unseen data. Selecting K by cross-validation on the training set would keep the test set independent. Note also that knn() breaks tied votes at random, so repeated runs with the same K can give slightly different predictions.

The line plot shows how accuracy varies with K: the sky-blue line traces the trend, and the red points mark individual K values. Accuracy starts relatively low but increases rapidly as K grows, peaking at around K = 54 with an accuracy of approximately 72%. Beyond this point, accuracy declines slightly and stabilizes around 70.5-71% for larger K values. The fluctuations at small K indicate that the model is highly sensitive to small changes in the data, which is a sign of overfitting. Very large K values, in turn, over-smooth the decision boundary and underfit.
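Selecting K on the training data alone can be sketched with knn.cv() from the class package, which performs leave-one-out cross-validation. The helper select_k_loocv() and the simulated data are hypothetical. With the document's objects, the call would be select_k_loocv(X_train_scaled, Y_train_factor, 1:250), and the chosen K would then be evaluated once on X_test_scaled.

```r
library(class)  # provides knn() and knn.cv()

# Hypothetical helper: pick K by leave-one-out cross-validation on the training set only,
# keeping the test set untouched for the final accuracy estimate
select_k_loocv = function(X, y, k_values) {
  cv_accuracy = sapply(k_values, function(k) mean(knn.cv(train = X, cl = y, k = k) == y))
  list(best_k = k_values[which.max(cv_accuracy)], cv_accuracy = cv_accuracy)
}

# Simulated stand-in for X_train_scaled / Y_train_factor (not the real data)
set.seed(123)  # knn.cv() breaks tied votes at random
X_sim = matrix(rnorm(300 * 5), ncol = 5)
y_sim = factor(ifelse(X_sim[, 1] + rnorm(300) > -0.5, 1, -1))

# Odd K values avoid tied votes between the two classes
res = select_k_loocv(X_sim, y_sim, k_values = seq(1, 51, by = 2))
res$best_k
```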

Train KNN Model and Confusion Matrix

# Train final KNN model with the best K
knn_final = knn(train = X_train_scaled, test = X_test_scaled, cl = Y_train, k = best_k)

# Compute Confusion Matrix
conf_matrix_knn = confusionMatrix(as.factor(knn_final), as.factor(Y_test))

# Print the confusion matrix and summary statistics
print(conf_matrix_knn)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  -1   1
##         -1  21  12
##         1  276 708
##                                          
##                Accuracy : 0.7168         
##                  95% CI : (0.688, 0.7443)
##     No Information Rate : 0.708          
##     P-Value [Acc > NIR] : 0.2799         
##                                          
##                   Kappa : 0.0731         
##                                          
##  Mcnemar's Test P-Value : <2e-16         
##                                          
##             Sensitivity : 0.07071        
##             Specificity : 0.98333        
##          Pos Pred Value : 0.63636        
##          Neg Pred Value : 0.71951        
##              Prevalence : 0.29204        
##          Detection Rate : 0.02065        
##    Detection Prevalence : 0.03245        
##       Balanced Accuracy : 0.52702        
##                                          
##        'Positive' Class : -1             
## 
# Convert confusion matrix to a data frame
conf_matrix_knn_df = as.data.frame.table(conf_matrix_knn$table)

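# Rename columns for clarity; as.data.frame.table() lists the table's dimensions
# in order, so the first column is Prediction (rows) and the second is Reference
colnames(conf_matrix_knn_df) = c("Predicted_Class", "Actual_Class", "Frequency")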

# Convert frequency to numeric
conf_matrix_knn_df$Frequency = as.numeric(conf_matrix_knn_df$Frequency)

# Plot the confusion matrix
ggplot(data = conf_matrix_knn_df, aes(x = Predicted_Class, y = Actual_Class, fill = Frequency)) +
  geom_tile(color = "white") +  # Creates the heatmap
  geom_text(aes(label = Frequency), vjust = 0.5, size = 5) +  # Adds text labels for each cell
  scale_fill_gradient(low = "white", high = "red3") +  # Color gradient for intensity
  labs(
    x = "Predicted Class",
    y = "Actual Class",
    title = "Confusion Matrix for K-Nearest Neighbors Model"
  )

This code trains a final K-Nearest Neighbors (KNN) model using the previously determined optimal K value (54) and evaluates its performance using a confusion matrix. The model is trained on X_train_scaled and tested on X_test_scaled, with knn_final storing the predicted classifications. The confusion matrix is computed in a similar manner to the one computed for the logistic regression model.

The results show that the KNN model achieved an accuracy of 71.68%, slightly outperforming logistic regression (71.29%). Specificity remains high at 98.33%, meaning the model identifies positive feedback (1) well. Sensitivity, however, is low at 7.07%, indicating that the model still struggles to classify negative feedback (-1). As with logistic regression, the accuracy is not significantly higher than the no-information rate of 70.8% (p = 0.2799). This imbalance reflects a continued bias toward predicting the majority class, which is expected given the class imbalance in the dataset.

The Kappa statistic (0.0731) is slightly higher than that of logistic regression (0.0501) but still reflects weak agreement between predicted and actual labels beyond chance. McNemar's test (p < 2e-16) shows that the two kinds of error are highly asymmetric: 276 negative-feedback trials are predicted as positive, versus 12 positive-feedback trials predicted as negative. This reinforces the model's difficulty in detecting negative feedback.

After testing the KNN model, I turned to Support Vector Machines (SVM) to see whether they could improve classification performance. KNN provided a slight improvement over logistic regression, but its low sensitivity showed that it still struggled with the class imbalance. SVM, particularly with a radial kernel, can model non-linear decision boundaries and may separate the classes more effectively.

Support Vector Machine (SVM)

library(e1071)      # SVM package

SVM model

# Y_train is numeric (-1/1), so svm() fits an eps-regression model rather than a classifier
svm_model = svm(Y_train ~ ., data = X_train_scaled, kernel = "radial")

# Make predictions
predictions = predict(svm_model, X_test_scaled)

# Threshold the continuous regression output (not a probability) at 0.5
predicted_labels = as.numeric(ifelse(predictions > 0.5, 1, -1))

# Evaluate the performance of the SVM classifier
accuracy = mean(predicted_labels == Y_test)

print(paste("SVM Accuracy:", accuracy))
## [1] "SVM Accuracy: 0.710914454277286"

The code trains an SVM with a radial basis function (RBF) kernel, which can capture nonlinear patterns in the data. It uses all columns of X_train_scaled as predictors and Y_train as the response. Because Y_train is numeric (-1 and 1) rather than a factor, svm() from e1071 fits an epsilon-regression SVM instead of a classifier. As a result, the predictions on X_test_scaled are continuous scores, not probabilities. These scores are converted into class labels by thresholding at 0.5 (1 if above 0.5, -1 otherwise), and accuracy is calculated by comparing the predicted labels with the actual labels. Fitting the model with a factor response (for example, Y_train_factor) would instead produce a C-classification SVM that returns class labels directly.

The output indicates that the SVM model achieved an accuracy of 71.09%, which falls slightly short of both logistic regression (71.29%) and KNN (71.68%). The similar accuracy across models highlights how difficult this data is to classify, most likely because of the class imbalance.
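The regression-versus-classification behavior of svm() can be seen on a small simulated problem. In e1071, the default type depends on the class of the response: a numeric y gives eps-regression and a factor y gives C-classification. The sketch below is illustrative only, using toy data and hypothetical object names. With the document's objects, a classifier would be fitted as svm(Y_train_factor ~ ., data = X_train_scaled, kernel = "radial"), optionally with probability = TRUE if class probabilities are needed. Its results would differ from those reported here.

```r
library(e1071)

# Toy two-class problem with a circular (non-linear) boundary; simulated, not the real data
set.seed(42)
n = 200
X_sim = cbind(x1 = rnorm(n), x2 = rnorm(n))
y_num = ifelse(X_sim[, "x1"]^2 + X_sim[, "x2"]^2 + rnorm(n, sd = 0.3) > 1.4, 1, -1)

# Numeric response: svm() defaults to eps-regression, so predict() returns continuous scores
svm_reg = svm(x = X_sim, y = y_num, kernel = "radial")
# Factor response: svm() defaults to C-classification, so predict() returns class labels
svm_clf = svm(x = X_sim, y = factor(y_num), kernel = "radial")

head(predict(svm_reg, X_sim))   # continuous scores on the -1/1 label scale
head(predict(svm_clf, X_sim))   # factor with levels "-1" and "1"
mean(predict(svm_clf, X_sim) == factor(y_num))  # training accuracy of the classifier
```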

Confusion Matrix

# Compute confusion matrix
conf_matrix_svm = confusionMatrix(factor(predicted_labels), factor(Y_test))

print(conf_matrix_svm)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  -1   1
##         -1  20  17
##         1  277 703
##                                          
##                Accuracy : 0.7109         
##                  95% CI : (0.682, 0.7386)
##     No Information Rate : 0.708          
##     P-Value [Acc > NIR] : 0.4334         
##                                          
##                   Kappa : 0.0589         
##                                          
##  Mcnemar's Test P-Value : <2e-16         
##                                          
##             Sensitivity : 0.06734        
##             Specificity : 0.97639        
##          Pos Pred Value : 0.54054        
##          Neg Pred Value : 0.71735        
##              Prevalence : 0.29204        
##          Detection Rate : 0.01967        
##    Detection Prevalence : 0.03638        
##       Balanced Accuracy : 0.52186        
##                                          
##        'Positive' Class : -1             
## 
# Convert confusion matrix to data frame
conf_matrix_df_svm = as.data.frame(conf_matrix_svm$table)

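# Rename columns for clarity; as.data.frame() on a table lists its dimensions
# in order, so the first column is Prediction (rows) and the second is Reference
colnames(conf_matrix_df_svm) = c("Predicted_Class", "Actual_Class", "Frequency")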

# Convert to factors for proper ordering in ggplot
conf_matrix_df_svm$Actual_Class = as.factor(conf_matrix_df_svm$Actual_Class)
conf_matrix_df_svm$Predicted_Class = as.factor(conf_matrix_df_svm$Predicted_Class)

# Plot Confusion Matrix as a Heatmap
ggplot(data = conf_matrix_df_svm, aes(x = Predicted_Class, y = Actual_Class, fill = Frequency)) +
  geom_tile(color = "white") +  # Creates the heatmap
  geom_text(aes(label = Frequency), vjust = 0.5, size = 5) +  # Adds text labels
  scale_fill_gradient(low = "white", high = "red") +  # Color gradient for intensity
  labs(
    x = "Predicted Class",
    y = "Actual Class",
    title = "Confusion Matrix for SVM Model"
  )

This code evaluates the SVM model’s performance by computing and visualizing the confusion matrix, using the same approach as in the logistic regression and KNN models.

The SVM model shows a strong imbalance in classification performance. Its sensitivity is just 6.73% (caret again treats -1 as the "positive" class), meaning it correctly identifies only a small fraction of negative feedback cases (-1). In contrast, specificity is very high (97.64%), so nearly all positive-feedback trials (1) are classified correctly. This discrepancy reflects a strong bias toward predicting the majority class. The confusion matrix shows the same pattern: 277 negative-feedback trials misclassified as positive far outnumber the 20 that were correctly identified.

The Kappa statistic (0.0589) remains low, signaling that the model’s predictive performance is only marginally better than random guessing. Additionally, McNemar’s test (p < 2e-16) confirms a significant imbalance in misclassification rates between the two classes. The balanced accuracy of 52.19%, which accounts for both sensitivity and specificity, is only slightly above chance, reinforcing that the model struggles to distinguish between the two feedback types effectively.

While the overall accuracy of 71.09% might seem reasonable, it primarily reflects the model’s tendency to favor the majority class (1). Given this strong imbalance, the model’s ability to generalize effectively remains limited.

After evaluating the SVM model, which showed moderate accuracy but struggled with class imbalance, we now move to XGBoost, a more advanced boosting-based algorithm that often performs well in structured data tasks.

XGBoost

Creating an XGBoost model

library(xgboost)
# Convert to matrix because XGBoost requires a matrix (non-list)
# X_train_xgb = train_data %>%
#   select(-c("mouse_name", "feedback_type", "date_exp", "session_id", "success")) %>%
#   as.matrix()

# X_test_xgb = test_data %>%
#   select(-c("mouse_name", "feedback_type", "date_exp", "session_id", "success")) %>%
#   as.matrix()

X_train_xgb = as.matrix(X_train)
X_test_xgb = as.matrix(X_test)

# Convert to binary (because XGBoost requires it)
Y_train_xgb = as.numeric(ifelse(Y_train == -1, 0, 1))
Y_test_xgb = as.numeric(ifelse(Y_test == -1, 0, 1))

# Create XGBoost DMatrix
dtrain = xgb.DMatrix(data = X_train_xgb, label = Y_train_xgb)

# Train the XGBoost model
xgb_model = xgboost(data = dtrain,
                    objective = "binary:logistic",  # Binary classification (minimizes logistic loss)
                    eval_metric = "auc",            # Evaluation metric only; does not change the objective
                    nrounds = 35,
                    max_depth = 6,
                    eta = 0.1,                      # Learning rate
                    verbose = 0)

The code prepares the dataset for XGBoost training so that the inputs match the model’s requirements. First, the feature matrices (X_train_xgb and X_test_xgb) are converted to numeric matrices, because the XGBoost R interface expects a matrix rather than a list-based structure such as a tibble. The target labels (Y_train_xgb and Y_test_xgb) are recoded from -1/1 to 0/1, because the binary:logistic objective expects labels in [0, 1]. The training data is then stored in an xgb.DMatrix, XGBoost’s optimized internal data format. Finally, the model is trained for 35 boosting rounds with a maximum tree depth of 6 and a learning rate (eta) of 0.1. The binary:logistic objective makes the model minimize logistic loss and output the probability of positive feedback. Setting eval_metric = "auc" only selects the metric XGBoost reports during training. It does not change what the model optimizes.
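The distinction between the training objective and the evaluation metric can be made concrete with a small base-R sketch. It computes both quantities for a made-up set of labels and predicted probabilities. The first is the logistic (log) loss, which is what `binary:logistic` minimizes. The second is the AUC, which `eval_metric = "auc"` only reports. The toy vectors and helper names are hypothetical.

```r
# Hypothetical toy data: 0 = negative feedback, 1 = positive feedback
labels <- c(0, 0, 1, 1, 1, 1)
probs  <- c(0.30, 0.60, 0.55, 0.70, 0.80, 0.90)  # predicted P(label = 1)

# Logistic loss: the quantity the binary:logistic objective minimizes
log_loss <- function(y, p) -mean(y * log(p) + (1 - y) * log(1 - p))

# AUC: share of (positive, negative) pairs ranked correctly, ties count 1/2
pairwise_auc <- function(y, p) {
  pos <- p[y == 1]
  neg <- p[y == 0]
  mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
}

log_loss(labels, probs)
pairwise_auc(labels, probs)

# Squaring keeps the ranking, so the AUC is unchanged but the log loss is not
pairwise_auc(labels, probs^2)
log_loss(labels, probs^2)
```

AUC depends only on how the predictions are ranked, whereas log loss also depends on how well calibrated the probabilities are. For this reason, monitoring AUC does not change the model that training produces.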

# Convert test set to XGBoost DMatrix
X_test_xgb_D = xgb.DMatrix(data = X_test_xgb)

# Make probability predictions
predictions = predict(xgb_model, X_test_xgb_D)

# Convert probabilities to class labels (1 if prob > 0.5, else 0)
predicted_labels = as.numeric(ifelse(predictions > 0.5, 1, 0))

# Compute Accuracy
accuracy = mean(predicted_labels == Y_test_xgb)
print(paste("XGBoost Accuracy:", accuracy))
## [1] "XGBoost Accuracy: 0.724680432645034"

The XGBoost accuracy of 72.47% is a slight improvement over all three previous models (SVM, 71.09%; Logistic Regression, 71.29%; KNN, 71.68%). Because the margins are small (at most about 1.4 percentage points), XGBoost’s overall accuracy is broadly comparable to that of the other models rather than clearly superior.
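Exact binomial confidence intervals give a rough sense of whether a gap of about one percentage point is meaningful. caret’s confusionMatrix reports this kind of interval as its 95% CI. The numbers of correct predictions below are an assumption, back-calculated from the reported accuracies and the 1,017 validation trials. All four models were scored on the same trials, so a paired test on per-trial predictions (for example McNemar’s test between two models) would be the stricter comparison. This sketch only checks whether the intervals overlap.

```r
# Correct predictions out of 1017 validation trials (back-calculated)
n_valid <- 1017
correct <- c(SVM = 723, Logistic = 725, KNN = 729, XGBoost = 737)

# Exact (Clopper-Pearson) 95% interval for each model's accuracy
acc_ci <- t(sapply(correct, function(k) {
  ci <- binom.test(k, n_valid)$conf.int
  c(accuracy = k / n_valid, lower = ci[1], upper = ci[2])
}))
round(acc_ci, 4)
```

Every model’s interval contains the other models’ point accuracies. This supports reading the accuracy differences as small.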

Confusion Matrix

# Compute confusion matrix
conf_matrix_xgb = confusionMatrix(factor(predicted_labels, levels = c(0, 1)), 
                                   factor(Y_test_xgb, levels = c(0, 1)))

# Print results
print(conf_matrix_xgb)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0  43  26
##          1 254 694
##                                           
##                Accuracy : 0.7247          
##                  95% CI : (0.6961, 0.7519)
##     No Information Rate : 0.708           
##     P-Value [Acc > NIR] : 0.1273          
##                                           
##                   Kappa : 0.1403          
##                                           
##  Mcnemar's Test P-Value : <2e-16          
##                                           
##             Sensitivity : 0.14478         
##             Specificity : 0.96389         
##          Pos Pred Value : 0.62319         
##          Neg Pred Value : 0.73207         
##              Prevalence : 0.29204         
##          Detection Rate : 0.04228         
##    Detection Prevalence : 0.06785         
##       Balanced Accuracy : 0.55434         
##                                           
##        'Positive' Class : 0               
## 
# Convert confusion matrix to a data frame
conf_matrix_df_xgb = as.data.frame(conf_matrix_xgb$table)

# Rename columns (as.data.frame() returns Prediction, Reference, Freq in that order)
colnames(conf_matrix_df_xgb) = c("Predicted_Class", "Actual_Class", "Frequency")

# Convert to factors for proper ordering in ggplot
conf_matrix_df_xgb$Actual_Class = as.factor(conf_matrix_df_xgb$Actual_Class)
conf_matrix_df_xgb$Predicted_Class = as.factor(conf_matrix_df_xgb$Predicted_Class)

# Plot Confusion Matrix as a Heatmap
ggplot(data = conf_matrix_df_xgb, aes(x = Predicted_Class, y = Actual_Class, fill = Frequency)) +
  geom_tile(color = "white") +  # Creates the heatmap
  geom_text(aes(label = Frequency), vjust = 0.5, size = 5) +  # Adds text labels
  scale_fill_gradient(low = "white", high = "red2") +  # Color gradient for intensity
  labs(
    x = "Predicted Class",
    y = "Actual Class",
    title = "Confusion Matrix for XGBoost Model"
  )
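Converting caret’s confusion-matrix table to a data frame produces columns in the table’s dimension order. caret stores predictions as the first dimension and the reference labels as the second, so the renamed columns must follow that order. The base-R sketch below reproduces the layout with hypothetical toy labels.

```r
# Hypothetical toy labels, tabulated with predictions first as caret does
predicted <- factor(c(0, 1, 1, 1, 0), levels = c(0, 1))
actual    <- factor(c(0, 1, 0, 1, 1), levels = c(0, 1))
tab <- table(predicted, actual, dnn = c("Prediction", "Reference"))

cm_df <- as.data.frame(tab)
names(cm_df)   # "Prediction" "Reference" "Freq"

# Rename in the same order as the columns so each label matches its data
colnames(cm_df) <- c("Predicted_Class", "Actual_Class", "Frequency")
cm_df
```

If the first two names are swapped, the heatmap silently transposes the matrix. The off-diagonal error counts then appear under the wrong axis labels.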

Using the same confusion matrix analysis as the previous models, we evaluate the performance of the XGBoost model.

In addition to having the highest accuracy, XGBoost maintained a high specificity (96.39%), meaning it correctly classified most positive feedback cases (1). However, sensitivity remained low at 14.48%, indicating difficulty in identifying negative feedback cases (coded 0 here, -1 in the original labels). Because caret treats 0 as the positive class, the positive predictive value (62.32%) means that about 62% of trials predicted as negative feedback were truly negative. The negative predictive value (73.21%) means that most trials predicted as positive feedback were correct. The accuracy of 72.47% is also not significantly higher than the no-information rate of 70.8%, which is the accuracy obtained by always predicting the majority class (P-Value [Acc > NIR] = 0.1273). McNemar’s test (p < 2e-16) confirms that classification errors are significantly imbalanced. The model predicted positive feedback for 254 negative-feedback trials, but made only 26 errors in the other direction. Despite this, XGBoost achieved the highest Kappa score (0.1403) among all models, reflecting slightly stronger agreement with the actual labels beyond chance.
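Both p-values can be checked from the printed counts with base R’s `binom.test` and `mcnemar.test`, which are the tests caret uses for these two statistics. The sketch hard-codes the XGBoost validation counts shown above.

```r
# XGBoost validation confusion matrix as printed by caret
# (rows = Prediction, columns = Reference; class 0 = negative feedback)
cm <- matrix(c(43, 254, 26, 694), nrow = 2,
             dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))

n   <- sum(cm)
acc <- sum(diag(cm)) / n
nir <- max(colSums(cm)) / n   # accuracy of always predicting the majority class

# One-sided exact binomial test: is accuracy above the no-information rate?
p_acc_vs_nir <- binom.test(sum(diag(cm)), n, p = nir, alternative = "greater")$p.value

# McNemar's test compares the two off-diagonal cells (254 vs 26 errors)
p_mcnemar <- mcnemar.test(cm)$p.value

round(c(accuracy = acc, nir = nir, p_acc_vs_nir = p_acc_vs_nir), 4)
p_mcnemar
```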

Model Comparison and Conclusion

Among the four models evaluated (Logistic Regression, KNN, SVM, and XGBoost), XGBoost stands out as the best-performing model. While KNN, SVM, and Logistic Regression all demonstrated high specificity, they struggled with classifying negative feedback (0), showing an overwhelming bias toward predicting positive feedback (1).

Although XGBoost still exhibits this bias, it had the best overall balance across key metrics:

  • Highest accuracy (72.47%)
  • Highest sensitivity (14.48%), meaning it detected more negative cases than the other models.
  • Highest balanced accuracy (55.43%), which accounts for both sensitivity and specificity.
  • Highest Kappa score (0.1403), indicating slightly better predictive agreement beyond chance.

While XGBoost does not solve the class imbalance issue, it offers the best trade-off between accuracy, sensitivity, and specificity of the four models. Its margins over the other models are small, but on these metrics it is the most reliable choice for classification in this context.


Evaluation of Test Dataset

# Load the data
testRDS = list()
for(i in 1:2){
  file_path = paste('/Users/jovinlouie/Desktop/UC Davis/WQ 25/STA 141A/STA141AProject/Data/test/test', i, '.rds', sep='')
  testRDS[[i]] = readRDS(file_path)
}
# Modified `get_session_usable_data` and `get_trial_data` functions for the test
get_test_trial_data = function(session_id, trial_id){
  spikes = testRDS[[session_id]]$spks[[trial_id]]
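  if (any(is.na(spikes))){
    warning("value missing")  # disp() is not a base R function; warning() flags the problem
  }

  # Average across neurons within each time bin, giving one row of 40 features;
  # `data_name` (data1 ... data40) is defined earlier in the report
  trial_bin_average = matrix(colMeans(spikes), nrow = 1)
  colnames(trial_bin_average) = data_name
  trial_tibble = as_tibble(trial_bin_average) %>% 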
    add_column("trial_id" = trial_id) %>% 
    add_column("contrast_left"= testRDS[[session_id]]$contrast_left[trial_id]) %>% 
    add_column("contrast_right"= testRDS[[session_id]]$contrast_right[trial_id]) %>% 
    add_column("feedback_type"= testRDS[[session_id]]$feedback_type[trial_id])
  
  return(trial_tibble)
}
get_test_useable_data = function(session_id){
  n_trial = length(testRDS[[session_id]]$spks)
  trial_list = list()
  
  for (trial_id in 1:n_trial){
    trial_tibble = get_test_trial_data(session_id, trial_id)  # Fetch trial data
    trial_list[[trial_id]] = trial_tibble
  }
  
  # Combine trials into a single tibble for the session
  session_tibble = as_tibble(do.call(rbind, trial_list))
  
  # Add relevant metadata for each session
  session_tibble = session_tibble %>% 
    add_column("mouse_name" = testRDS[[session_id]]$mouse_name) %>% 
    add_column("date_exp" = testRDS[[session_id]]$date_exp) %>% 
    add_column("session_id" = session_id)
  
  return(session_tibble)
}

# Convert test sessions into a structured tibble
test_data_list = list()
for (testing_id in 1:2) {
  test_data_list[[testing_id]] = get_test_useable_data(testing_id)
}

test_data_tibble = as_tibble(do.call(rbind, test_data_list))
test_data_tibble$session_id = as.factor(test_data_tibble$session_id)
test_data_tibble$contrast_diff = abs(test_data_tibble$contrast_left - test_data_tibble$contrast_right)
test_data_tibble$success = test_data_tibble$feedback_type == 1
test_data_tibble$success = as.numeric(test_data_tibble$success)

# Display the first few rows of processed test data
head(test_data_tibble)
## # A tibble: 6 × 49
##    data1  data2  data3  data4  data5  data6  data7  data8  data9 data10 data11
##    <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
## 1 0.0327 0.0259 0.0232 0.0259 0.0300 0.0286 0.0327 0.0450 0.0436 0.0272 0.0327
## 2 0.0354 0.0354 0.0381 0.0272 0.0341 0.0518 0.0477 0.0341 0.0327 0.0232 0.0259
## 3 0.0354 0.0313 0.0245 0.0327 0.0204 0.0136 0.0300 0.0313 0.0272 0.0395 0.0490
## 4 0.0341 0.0259 0.0422 0.0259 0.0232 0.0409 0.0463 0.0422 0.0395 0.0327 0.0368
## 5 0.0313 0.0300 0.0218 0.0518 0.0695 0.0463 0.0300 0.0163 0.0232 0.0204 0.0450
## 6 0.0490 0.0436 0.0341 0.0463 0.0409 0.0381 0.0490 0.0381 0.0409 0.0313 0.0368
## # ℹ 38 more variables: data12 <dbl>, data13 <dbl>, data14 <dbl>, data15 <dbl>,
## #   data16 <dbl>, data17 <dbl>, data18 <dbl>, data19 <dbl>, data20 <dbl>,
## #   data21 <dbl>, data22 <dbl>, data23 <dbl>, data24 <dbl>, data25 <dbl>,
## #   data26 <dbl>, data27 <dbl>, data28 <dbl>, data29 <dbl>, data30 <dbl>,
## #   data31 <dbl>, data32 <dbl>, data33 <dbl>, data34 <dbl>, data35 <dbl>,
## #   data36 <dbl>, data37 <dbl>, data38 <dbl>, data39 <dbl>, data40 <dbl>,
## #   trial_id <int>, contrast_left <dbl>, contrast_right <dbl>, …

Before evaluating the test dataset, I loaded all the test RDS files and passed them through modified versions of the functions used for the sessions dataset. The original functions were hard-coded to the training sessions, so the test versions read from testRDS instead. They produce the same features: 40 bin-averaged spike values (data1 to data40) plus trial and session metadata. Reusing the same pipeline keeps the test features consistent with the features the model was trained on.
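The feature construction in get_test_trial_data() averages each trial’s spike matrix over neurons (rows) within each time bin (columns). Sessions can record different numbers of neurons, but this averaging always yields the same 40 features per trial. The base-R sketch below illustrates this with randomly generated, hypothetical spike matrices. It assumes 40 time bins per trial, as the data1 to data40 columns indicate.

```r
set.seed(1)
# Hypothetical spike matrix for one trial: 5 neurons x 40 time bins
spikes <- matrix(rbinom(5 * 40, size = 1, prob = 0.03), nrow = 5, ncol = 40)

# Average across neurons within each time bin, as in get_test_trial_data()
bin_average <- matrix(colMeans(spikes), nrow = 1)
colnames(bin_average) <- paste0("data", 1:40)   # assumed to match `data_name`
dim(bin_average)   # 1 40

# Trials with different neuron counts still give 40 features each
trials <- list(
  matrix(rbinom(5 * 40, 1, 0.03), nrow = 5),
  matrix(rbinom(8 * 40, 1, 0.03), nrow = 8)
)
features <- do.call(rbind, lapply(trials, colMeans))
colnames(features) <- paste0("data", 1:40)
dim(features)   # 2 40
```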

Accuracy

test_data_X = test_data_tibble %>%
  select(-c("trial_id", "feedback_type", "mouse_name", "date_exp", "session_id", "contrast_diff", "success")) %>%
  as.matrix()

test_data_Y = test_data_tibble %>%
  select("feedback_type") %>%
  pull()  # Convert to vector

test_data_X_xgb = xgb.DMatrix(data = test_data_X)
# test_data_Y = as.numeric(ifelse(Y_test == -1, 0, 1))

# Make predictions
predictions_test = predict(xgb_model, test_data_X_xgb)

predicted_labels_test = as.numeric(ifelse(predictions_test > 0.5, 1, -1))

accuracy = mean(predicted_labels_test == test_data_Y)

print(paste("Accuracy:", accuracy))
## [1] "Accuracy: 0.715"

I separated the test data into a feature matrix and a label vector in the same way as the sessions data, dropping the trial and session metadata columns.

I use accuracy as the first criterion to evaluate the model’s performance on the test data.

The test accuracy (71.5%) is slightly lower than the validation accuracy (72.47%), so overall performance is fairly consistent across datasets. However, 71.5% is below the 72.5% that would be obtained by always predicting positive feedback (the no-information rate reported in the confusion matrix below), so accuracy alone says little here. The slight drop may also reflect session-specific variations in neural activity that the model does not fully generalize across.
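Whether the drop from validation to test accuracy is larger than sampling noise can be checked with a two-sample test of proportions. The sketch below uses the correct-prediction counts from the two confusion matrices in this report (43 + 694 of 1,017 validation trials, 6 + 137 of 200 test trials). It treats the two sets as independent samples, which is an assumption.

```r
# Correct predictions / trials: validation split vs. held-out test sessions
correct <- c(validation = 737, test = 143)
trials  <- c(validation = 1017, test = 200)

# Two-sample test of equal proportions (with continuity correction)
gap_test <- prop.test(correct, trials)
round(correct / trials, 4)
gap_test$p.value
```

With only 200 test trials, a one-percentage-point difference is well within sampling variability, so the decline should be read as "no clear change" rather than as evidence of overfitting.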

Confusion Matrix

# Compute confusion matrix
conf_matrix_test = confusionMatrix(factor(predicted_labels_test, levels = c(-1, 1)),
                                    factor(test_data_Y, levels = c(-1, 1)))

print(conf_matrix_test)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  -1   1
##         -1   6   8
##         1   49 137
##                                           
##                Accuracy : 0.715           
##                  95% CI : (0.6471, 0.7764)
##     No Information Rate : 0.725           
##     P-Value [Acc > NIR] : 0.6575          
##                                           
##                   Kappa : 0.0701          
##                                           
##  Mcnemar's Test P-Value : 1.17e-07        
##                                           
##             Sensitivity : 0.1091          
##             Specificity : 0.9448          
##          Pos Pred Value : 0.4286          
##          Neg Pred Value : 0.7366          
##              Prevalence : 0.2750          
##          Detection Rate : 0.0300          
##    Detection Prevalence : 0.0700          
##       Balanced Accuracy : 0.5270          
##                                           
##        'Positive' Class : -1              
## 
# Convert confusion matrix to data frame
conf_matrix_df_test = as.data.frame(conf_matrix_test$table)

# Rename columns (as.data.frame() returns Prediction, Reference, Freq in that order)
colnames(conf_matrix_df_test) = c("Predicted_Class", "Actual_Class", "Frequency")

# Convert to factors for proper ordering in ggplot
conf_matrix_df_test$Actual_Class = as.factor(conf_matrix_df_test$Actual_Class)
conf_matrix_df_test$Predicted_Class = as.factor(conf_matrix_df_test$Predicted_Class)

# Plot Confusion Matrix as a Heatmap
ggplot(data = conf_matrix_df_test, aes(x = Predicted_Class, y = Actual_Class, fill = Frequency)) +
  geom_tile(color = "white") +  # Creates the heatmap
  geom_text(aes(label = Frequency), vjust = 0.5, size = 5) +  # Adds text labels
  scale_fill_gradient(low = "white", high = "red") +  # Color gradient for intensity
  labs(
    x = "Predicted Class",
    y = "Actual Class",
    title = "Confusion Matrix for XGBoost Model on Test Dataset"
  )

My second evaluation criterion is the confusion matrix, which, as with the earlier models, breaks the results down beyond overall accuracy. Accuracy alone hides the model’s difficulty in distinguishing between feedback types. Sensitivity drops to 10.91% (6 of 55 negative-feedback trials detected), indicating that the model struggles to identify negative feedback (-1). Specificity remains strong at 94.48% (137 of 145 positive-feedback trials), confirming a strong bias toward predicting positive feedback (1).

Compared with the validation phase (14.48% sensitivity, 96.39% specificity), the test results show a further decline in sensitivity and overall class differentiation. This is reflected in the balanced accuracy of 52.7%, down from 55.43% on validation. The Kappa statistic (0.0701) indicates weak agreement beyond chance. McNemar’s test (p = 1.17e-07) confirms a significant imbalance between the two types of misclassification (49 versus 8).

Although XGBoost achieved the highest validation accuracy of the four models, its low sensitivity highlights persistent challenges in identifying negative feedback. Further refinements could improve the balance of its classification performance, such as adjusting class weights, optimizing the decision threshold, or exploring feature selection.
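Threshold optimization can be sketched in base R. The idea is to sweep candidate cutoffs on the predicted probabilities and keep the one that maximizes balanced accuracy. This is equivalent to maximizing Youden’s J = sensitivity + specificity - 1. The labels and probabilities below are synthetic and hypothetical. With the real model, the threshold should be chosen on the validation split, not on the test sessions, to avoid an optimistic test estimate.

```r
set.seed(42)
# Synthetic, hypothetical data: ~28% negative feedback (0), weak signal,
# probabilities mostly above 0.5 (mimicking a majority-class bias)
n      <- 500
labels <- rbinom(n, 1, 0.72)
probs  <- plogis(0.9 + 0.8 * labels + rnorm(n))

balanced_accuracy <- function(y, p, threshold) {
  pred <- as.numeric(p > threshold)
  sens <- mean(pred[y == 0] == 0)   # recall for class 0 (negative feedback)
  spec <- mean(pred[y == 1] == 1)   # recall for class 1 (positive feedback)
  (sens + spec) / 2
}

thresholds <- (1:19) / 20                      # 0.05, 0.10, ..., 0.95
scores     <- sapply(thresholds, balanced_accuracy, y = labels, p = probs)
best       <- thresholds[which.max(scores)]

c(best_threshold = best,
  bal_acc_at_0.5 = scores[thresholds == 0.5],
  bal_acc_at_best = max(scores))
```

Because 0.5 is one of the candidates, the selected cutoff can never do worse than 0.5 on the data it was tuned on. That is exactly why it needs to be evaluated on separate data.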

Precision, Recall, and F1 Score

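# Compute precision, recall, and F1 score for class 1 on the validation split (Y_test_xgb);
# the "Training" labels printed below refer to this held-out split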
precision = posPredValue(factor(predicted_labels), factor(Y_test_xgb), positive="1")
recall = sensitivity(factor(predicted_labels), factor(Y_test_xgb), positive="1")
f1_score = 2 * ((precision * recall) / (precision + recall))

cat("Precision Training:", precision, "\n")
## Precision Training: 0.7320675
cat("Recall Training:", recall, "\n")
## Recall Training: 0.9638889
cat("F1 Score Training:", f1_score, "\n")
## F1 Score Training: 0.8321343
precision = posPredValue(factor(predicted_labels_test), factor(test_data_Y), positive="1")
recall = sensitivity(factor(predicted_labels_test), factor(test_data_Y), positive="1")
f1_score = 2 * ((precision * recall) / (precision + recall))

cat("\nPrecision Test Data:", precision, "\n")
## 
## Precision Test Data: 0.7365591
cat("Recall Test Data:", recall, "\n")
## Recall Test Data: 0.9448276
cat("F1 Score Test Data:", f1_score, "\n")
## F1 Score Test Data: 0.8277946

Because XGBoost struggles with class imbalance, accuracy alone is not a reliable metric. I therefore calculated precision, recall, and F1-score, which account for false positives and false negatives separately. With 1 treated as the positive class, precision measures the share of predicted positive cases (1) that were correct. Recall (sensitivity) measures the share of actual positive cases the model identifies. The F1-score, the harmonic mean of precision and recall, is high only when both are high, which makes it useful when both kinds of error matter.

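Comparing the validation and test results helps assess generalization. The validation F1-score is 0.8321. It is printed with the label “Training”, but it is computed on the held-out validation split (Y_test_xgb). This value is nearly identical to the test F1-score (0.8278), which suggests consistent behavior across datasets. However, these metrics treat positive feedback (1) as the positive class, so they mainly describe the majority class. The high test recall (0.9448) comes from predicting 1 for 186 of the 200 test trials. The precision of 0.7366 means that about a quarter of those predictions (49 of 186) were actually negative feedback, so the model is liberal rather than conservative in predicting 1. Recall for negative feedback (-1) is only 10.91%, confirming the class imbalance observed in previous analyses.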
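Computing the metrics separately for each class makes the imbalance visible. The base-R sketch below uses the test-set counts printed above. It computes precision, recall, and F1 = 2TP / (2TP + FP + FN) with each class in turn treated as positive, then takes their unweighted (macro) average.

```r
# Test-set confusion matrix printed above (rows = Prediction, cols = Reference)
cm_test <- matrix(c(6, 49, 8, 137), nrow = 2,
                  dimnames = list(Prediction = c("-1", "1"), Reference = c("-1", "1")))

per_class <- sapply(c("-1", "1"), function(cls) {
  precision <- cm_test[cls, cls] / sum(cm_test[cls, ])   # row: predicted as cls
  recall    <- cm_test[cls, cls] / sum(cm_test[, cls])   # column: actually cls
  f1        <- 2 * precision * recall / (precision + recall)
  c(precision = precision, recall = recall, f1 = f1)
})
round(per_class, 4)
mean(per_class["f1", ])   # macro-averaged F1
```

The class-1 column reproduces the test precision (0.7366), recall (0.9448) and F1 (0.8278) reported above. The class -1 column shows an F1 of about 0.17, and the macro-averaged F1 of about 0.50 is a more sober summary of the model.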

ROC-AUC (Receiver Operating Characteristic - Area Under the Curve)

# Calculate ROC curve and AUC
library(pROC)

roc_curve = roc(test_data_Y, predictions_test)
## Setting levels: control = -1, case = 1
## Setting direction: controls < cases
# Convert to dataframe for ggplot
roc_df = data.frame(
  FPR = 1 - roc_curve$specificities,  # False Positive Rate
  TPR = roc_curve$sensitivities       # True Positive Rate
)

# Plot ROC curve
ggplot(roc_df, aes(x = FPR, y = TPR)) +
  geom_path(color = "blue") +  # Keep pROC's threshold order; geom_line() would re-sort by FPR and distort the steps
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "red") +  # Diagonal reference line
  labs(
    title = "ROC Curve for XGBoost Model",
    x = "False Positive Rate",
    y = "True Positive Rate"
  )

auc_value = auc(roc_curve)
cat("AUC:", auc_value, "\n")
## AUC: 0.6126646

In classification problems with imbalanced classes, accuracy can be misleading. A model can score well simply by favoring the majority class while failing on the minority class. To address this, I calculated the ROC-AUC (Receiver Operating Characteristic - Area Under the Curve).

The ROC curve plots the true positive rate against the false positive rate as the decision threshold varies. The AUC therefore summarizes how well the model separates the two classes across all thresholds, rather than only at the 0.5 cutoff used above.

The XGBoost model achieved an AUC of 0.6127, only modestly above the 0.5 expected from random guessing. The ROC curve does not bend sharply toward the upper-left corner, so the predicted probabilities separate positive (1) and negative (-1) feedback only weakly at any threshold. This weak separation, combined with the class imbalance, explains why the 0.5 cutoff favors the majority class and misses most negative feedback.
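An AUC of 0.6127 has a direct interpretation. Take a randomly drawn positive-feedback trial and a randomly drawn negative-feedback trial: the model gives the positive trial the higher predicted probability about 61% of the time. The base-R sketch below uses synthetic, hypothetical scores to show that this pairwise probability equals the rank-based (Mann-Whitney) formula. It also matches the W statistic from `wilcox.test` divided by the number of pairs.

```r
set.seed(7)
# Synthetic, hypothetical scores: class 1 shifted slightly upward
scores_pos <- rnorm(150, mean = 0.3)   # e.g. positive feedback (1)
scores_neg <- rnorm(60,  mean = 0.0)   # e.g. negative feedback (-1)

# AUC as the probability that a random positive outranks a random negative
auc_pairs <- mean(outer(scores_pos, scores_neg, ">"))

# Same quantity from ranks (Mann-Whitney U divided by the number of pairs)
n1 <- length(scores_pos)
n0 <- length(scores_neg)
r  <- rank(c(scores_pos, scores_neg))
auc_rank <- (sum(r[1:n1]) - n1 * (n1 + 1) / 2) / (n1 * n0)

c(pairs = auc_pairs, ranks = auc_rank)
```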

While the AUC shows some classification capability, the class imbalance remains a challenge. Further hyperparameter tuning, class balancing, and feature selection could improve predictive performance, particularly in detecting negative feedback cases.

The next section discusses these findings.


Discussion

This report aimed to identify the best predictive model for classifying feedback types in mice based on neural activity. I began with Exploratory Data Analysis (EDA) to examine neural spike activity patterns, success rate distributions, and differences across experimental sessions. Dimensionality reduction techniques (PCA, t-SNE) helped explore the high-dimensional neural data, while clustering methods assessed whether distinct activation patterns could aid classification. After structuring the raw spike data into meaningful features, I split the dataset into training and test sets and evaluated four classification models: Logistic Regression, K-Nearest Neighbors (KNN), Support Vector Machine (SVM), and XGBoost. Each model was assessed using accuracy, confusion matrices, precision, recall, F1-score, and ROC-AUC curves.

Among the models, XGBoost achieved the highest validation accuracy (72.47%), ahead of KNN (71.68%), Logistic Regression (71.29%), and SVM (71.09%). However, all models were affected by a substantial class imbalance. Positive feedback (1) made up about 71% of validation trials and 72.5% of test trials, so negative feedback cases (-1) were in the minority. This imbalance led to low sensitivity across all models, meaning negative feedback cases were frequently misclassified. XGBoost was the most balanced of the models, yet on the test sessions it reached a sensitivity of only 10.91% and a specificity of 94.48%, reflecting its bias toward predicting positive feedback. Its test AUC of 0.6127 indicates that it performs better than random guessing but still struggles to differentiate between feedback types. McNemar’s test confirmed that misclassifications were significantly imbalanced, further highlighting the difficulty of detecting negative feedback.

Addressing this inherent class imbalance is particularly challenging. Simple solutions such as oversampling the minority class (-1) or undersampling the majority class (1) could lead to overfitting or loss of valuable data. More advanced approaches like class weighting, cost-sensitive learning, synthetic data generation (e.g., SMOTE), or threshold adjustments could improve model performance but may still struggle to fully rectify the imbalance due to the fundamental nature of the dataset. Additionally, neural activity patterns related to negative feedback may be less distinct or harder to separate, making classification inherently more difficult regardless of the model used.
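Class weighting, one of the approaches mentioned above, can be sketched in base R. Each trial receives a weight inversely proportional to its class frequency, so both classes contribute the same total weight to the loss. The label vector below is illustrative only: its class counts mirror the validation split (297 negative, 720 positive), and it uses the 0/1 coding from the XGBoost section. How such weights would be passed to XGBoost is noted in a comment and is not executed here.

```r
# Illustrative 0/1 labels (0 = negative feedback, 1 = positive feedback)
y <- c(rep(0, 297), rep(1, 720))

# Inverse-frequency weights: each class ends up with the same total weight
class_counts <- table(y)
w <- as.numeric(length(y) / (length(class_counts) * class_counts[as.character(y)]))

tapply(w, y, sum)   # equal totals for class 0 and class 1

# With xgboost, weights like these could be supplied via
# xgb.DMatrix(data = X, label = y, weight = w). Alternatively, the
# scale_pos_weight parameter rescales only the class labelled 1; here that is
# the majority class, so a value below 1 (e.g. 297 / 720) would down-weight it.
```

Weighting keeps every trial, unlike undersampling, and does not duplicate minority trials, unlike oversampling. It still shifts the model toward the minority class, though, so the decision threshold and the evaluation metrics should be re-checked afterwards.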

Future improvements could explore hyperparameter tuning, ensemble methods, or deep learning approaches like Recurrent Neural Networks (RNNs) or Transformer-based models to capture temporal dependencies in neural activity. While XGBoost emerged as the best-performing model, further refinements are necessary to enhance its ability to detect negative feedback, ensuring more balanced and reliable classification.


sessionInfo()
## R version 4.4.3 (2025-02-28)
## Platform: aarch64-apple-darwin20
## Running under: macOS Sequoia 15.3.2
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib 
## LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.0
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## time zone: America/Los_Angeles
## tzcode source: internal
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] xgboost_1.7.8.1 e1071_1.7-16    class_7.3-23    pROC_1.18.5    
##  [5] Rtsne_0.17      caret_7.0-1     lattice_0.22-6  lubridate_1.9.4
##  [9] forcats_1.0.0   stringr_1.5.1   purrr_1.0.2     readr_2.1.5    
## [13] tidyr_1.3.1     tibble_3.2.1    tidyverse_2.0.0 dplyr_1.1.4    
## [17] ggplot2_3.5.1  
## 
## loaded via a namespace (and not attached):
##  [1] tidyselect_1.2.1     timeDate_4041.110    farver_2.1.2        
##  [4] fastmap_1.2.0        digest_0.6.37        rpart_4.1.24        
##  [7] timechange_0.3.0     lifecycle_1.0.4      survival_3.8-3      
## [10] magrittr_2.0.3       compiler_4.4.3       rlang_1.1.4         
## [13] sass_0.4.9           tools_4.4.3          utf8_1.2.4          
## [16] yaml_2.3.10          data.table_1.16.4    knitr_1.49          
## [19] labeling_0.4.3       plyr_1.8.9           withr_3.0.2         
## [22] nnet_7.3-20          grid_4.4.3           stats4_4.4.3        
## [25] fansi_1.0.6          colorspace_2.1-1     future_1.34.0       
## [28] globals_0.16.3       scales_1.3.0         iterators_1.0.14    
## [31] MASS_7.3-64          cli_3.6.3            rmarkdown_2.29      
## [34] crayon_1.5.3         generics_0.1.3       rstudioapi_0.17.1   
## [37] future.apply_1.11.3  reshape2_1.4.4       tzdb_0.4.0          
## [40] cachem_1.1.0         proxy_0.4-27         splines_4.4.3       
## [43] parallel_4.4.3       vctrs_0.6.5          hardhat_1.4.1       
## [46] Matrix_1.7-2         jsonlite_1.8.9       hms_1.1.3           
## [49] listenv_0.9.1        foreach_1.5.2        gower_1.0.2         
## [52] jquerylib_0.1.4      recipes_1.1.1        glue_1.8.0          
## [55] parallelly_1.42.0    codetools_0.2-20     stringi_1.8.4       
## [58] gtable_0.3.6         munsell_0.5.1        pillar_1.9.0        
## [61] htmltools_0.5.8.1    ipred_0.9-15         lava_1.8.1          
## [64] R6_2.5.1             evaluate_1.0.1       bslib_0.8.0         
## [67] Rcpp_1.0.13-1        nlme_3.1-167         prodlim_2024.06.25  
## [70] mgcv_1.9-1           xfun_0.49            pkgconfig_2.0.3     
## [73] ModelMetrics_1.2.2.2

References

  1. Lecture Notes
  2. Discussion Notes
  3. ChatGPT (explanation of certain topics and getting ideas)